結果
| 問題 |
No.986 Present
|
| コンテスト | |
| ユーザー |
lam6er
|
| 提出日時 | 2025-03-26 15:52:09 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
AC
|
| 実行時間 | 343 ms / 2,000 ms |
| コード長 | 2,094 bytes |
| コンパイル時間 | 266 ms |
| コンパイル使用メモリ | 82,764 KB |
| 実行使用メモリ | 70,648 KB |
| 最終ジャッジ日時 | 2025-03-26 15:52:52 |
| 合計ジャッジ時間 | 7,189 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 30 |
ソースコード
MOD = 998244353
def main():
import sys
input = sys.stdin.read().split()
N = int(input[0])
M = int(input[1])
d = min(N, M)
A = pow(2, d, MOD)
max_fact = max(N, M) if N <= M else 0
fact = [1] * (max_fact + 1)
for i in range(1, max_fact + 1):
fact[i] = fact[i-1] * i % MOD
inv_fact = [1] * (max_fact + 1)
if max_fact >= 0:
inv_fact[max_fact] = pow(fact[max_fact], MOD-2, MOD)
for i in range(max_fact - 1, -1, -1):
inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
if N <= M:
ans2 = 1
for i in range(N):
term = (pow(2, M, MOD) - pow(2, i, MOD)) % MOD
ans2 = ans2 * term % MOD
ans2 = ans2 * inv_fact[N] % MOD
else:
sum_terms = 0
for k in range(M + 1):
exponent = M - k
a_mod = pow(2, exponent, MOD)
if a_mod < N:
continue
numer = 1
for i in range(N):
numer = numer * (a_mod - i) % MOD
comb = numer * inv_fact[N] % MOD
comb_m_k = fact[M] * inv_fact[k] % MOD * inv_fact[M - k] % MOD
term = comb * comb_m_k % MOD
if k % 2 == 1:
term = (-term) % MOD
sum_terms = (sum_terms + term) % MOD
ans2 = sum_terms % MOD
if N <= M:
max_pow = M + 1
pow2 = [1] * (max_pow + 2)
for i in range(1, max_pow + 2):
pow2[i] = pow2[i-1] * 2 % MOD
inv_pow2_minus_1 = [0] * (max_pow + 2)
for i in range(1, max_pow + 2):
denom = (pow2[i] - 1) % MOD
inv_pow2_minus_1[i] = pow(denom, MOD-2, MOD) if denom != 0 else 0
ans3 = 1
for i in range(N):
numerator = (pow2[M - i] - 1) % MOD
denominator = (pow2[i+1] - 1) % MOD
inv_den = inv_pow2_minus_1[i+1]
ans3 = ans3 * numerator % MOD
ans3 = ans3 * inv_den % MOD
else:
ans3 = 1
print(f"{A} {ans2} {ans3}")
if __name__ == "__main__":
main()
lam6er