Result
Item | Value |
---|---|
Problem | No.1856 Mex Sum 2 |
Submitted | 2025-03-26 15:59:17 |
Language | PyPy3 (7.3.15) |
Result | TLE |
Execution time | - |
Code length | 2,390 bytes |
Compile time | 197 ms |
Compile memory usage | 82,168 KB |
Execution memory usage | 66,836 KB |
Last judged | 2025-03-26 16:00:20 |
Total judge time | 7,678 ms |
Judge server ID (for reference) | judge5 / judge3 |
File pattern | Result |
---|---|
sample | AC × 3 |
other | AC × 26, TLE × 1, -- × 37 (not run) |
Source code
```python
MOD = 998244353
n, M = map(int, input().split())
max_k = min(M + 1, n)

# Precompute combinations (n choose m)
comb = [[0] * (n + 1) for _ in range(n + 1)]
comb[0][0] = 1
for i in range(1, n + 1):
    comb[i][0] = 1
    for j in range(1, i + 1):
        comb[i][j] = (comb[i-1][j-1] + comb[i-1][j]) % MOD

ans = 0
for k in range(1, max_k + 1):
    if k > M:
        valid = 0
    else:
        valid = 1
    if k <= M:
        allowed = M - k  # number of elements > k-1 and !=k (which are k+1..M)
        elements = k + allowed  # 0..k-1 and k+1..M
    else:
        elements = M + 1  # k = M+1, so allowed elements are 0..M
    # Compute sum_terms for each m from 1 to n
    total = 0
    for m in range(1, n + 1):
        if k <= M:
            if elements == 0:
                continue
            # sum_{t=0}^k (-1)^t * C(k, t) * ( (k - t) + allowed )^m
            sum_terms = 0
            for t in range(0, k+1):
                sign = (-1) ** t
                c = comb[k][t]
                base = (k - t) + allowed
                term = pow(base, m, MOD)
                term = term * c % MOD
                if sign == -1:
                    term = (-term) % MOD
                sum_terms = (sum_terms + term) % MOD
        else:
            # k == M+1, elements = M+1 (0..M)
            # sum_{t=0}^{k} -> sum_{t=0}^{M+1} (-1)^t * C(M+1, t) * (M+1 - t)^m
            # but k = M+1 could be larger than n, but since max_k = min(M+1, n), k <= n
            sum_terms = 0
            for t in range(0, k + 1):
                if t > M + 1:
                    continue
                sign = (-1) ** t
                c = comb[k][t] if t <= k else 0
                base = (M + 1 - t) % MOD
                term = pow(base, m, MOD)
                term = term * c % MOD
                if sign == -1:
                    term = (-term) % MOD
                sum_terms = (sum_terms + term) % MOD
        # Compute the term for this m
        term = comb[n][m] * sum_terms % MOD
        if k <= M:
            # other elements can be anything (including k)
            # so (M+1)^(n - m)
            term = term * pow(M + 1, n - m, MOD) % MOD
        else:
            # k == M+1, other elements can be anything (but k is not present)
            term = term * pow(M + 1, n - m, MOD) % MOD
        total = (total + term) % MOD
    ans = (ans + k * total) % MOD
print(ans)
```
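The `sum_terms` loops in the submitted code are an inclusion-exclusion count. Because `allowed = M - k`, the base `(k - t) + allowed` simplifies to `M - t`, so for k ≤ M the sum counts length-m sequences over the M values 0..M other than k that contain every value 0..k-1; in the k = M + 1 branch the alphabet is 0..M itself. The sketch below, using hypothetical helper names `count_covering` and `brute_force`, checks that counting identity against brute-force enumeration for small sizes. It verifies only the identity, not the submission's overall answer, since the problem statement is not reproduced on this page.

```python
from itertools import product
from math import comb

MOD = 998244353


def count_covering(alphabet_size: int, k: int, m: int) -> int:
    """Number of length-m sequences over `alphabet_size` symbols that contain
    each of k fixed symbols at least once, via inclusion-exclusion:
    sum_{t=0}^{k} (-1)^t * C(k, t) * (alphabet_size - t)^m  (mod MOD)."""
    total = 0
    for t in range(k + 1):
        # t = number of required symbols forced to be absent
        term = comb(k, t) * pow(alphabet_size - t, m, MOD) % MOD
        if t % 2 == 1:
            total = (total - term) % MOD
        else:
            total = (total + term) % MOD
    return total


def brute_force(alphabet_size: int, k: int, m: int) -> int:
    """Enumerate every sequence; symbols 0..k-1 are the required ones."""
    required = set(range(k))
    hits = sum(
        1
        for seq in product(range(alphabet_size), repeat=m)
        if required <= set(seq)
    )
    return hits % MOD


# Exhaustive check on small alphabets (k <= alphabet_size, as in the code).
for s in range(1, 5):
    for k in range(0, s + 1):
        for m in range(0, 5):
            assert count_covering(s, k, m) == brute_force(s, k, m)
print("inclusion-exclusion matches brute force")
```

For example, `count_covering(3, 2, 3)` is 27 - 2·8 + 1 = 12: of the 27 length-3 sequences over {0, 1, 2}, twelve contain both 0 and 1.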
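The TLE is consistent with the loop structure of the submitted code: for each k in 1..min(M + 1, n) and each m in 1..n, the t-loop performs k + 1 modular exponentiations, followed by one more `pow(M + 1, n - m, MOD)`, so the work grows roughly as n · min(M + 1, n)² on top of the (n + 1)² Pascal table. (The `elements == 0` branch never fires, since `elements = M ≥ 1` whenever k ≤ M.) The sketch below, with hypothetical function names, counts those `pow()` calls both by mirroring the loops and by a closed form. The inputs shown are illustrative only; the actual constraints of No.1856 are not given on this page.

```python
def count_pow_calls(n: int, M: int) -> int:
    """Count pow() calls by mirroring the submission's loops: for every k in
    1..min(M+1, n) and every m in 1..n, the t-loop makes k+1 pow() calls and
    one more pow(M + 1, n - m, MOD) follows it."""
    calls = 0
    max_k = min(M + 1, n)
    for k in range(1, max_k + 1):
        for _m in range(1, n + 1):
            calls += (k + 1) + 1
    return calls


def count_pow_calls_closed_form(n: int, M: int) -> int:
    """Same count in closed form: n * (K(K+1)/2 + 2K) with K = min(M+1, n)."""
    K = min(M + 1, n)
    return n * (K * (K + 1) // 2 + 2 * K)


# Illustrative sizes only; not the problem's real limits.
for n, M in [(10, 10), (100, 100), (1000, 1000)]:
    print(n, M, count_pow_calls_closed_form(n, M))
```

For the illustrative input n = M = 1000, the closed form gives 502,500,000 `pow()` calls, which shows how quickly the triple loop becomes expensive in an interpreted runtime.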