結果
問題 |
No.1195 数え上げを愛したい(文字列編)
|
ユーザー |
![]() |
提出日時 | 2025-04-16 00:24:37 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 2,293 bytes |
コンパイル時間 | 392 ms |
コンパイル使用メモリ | 81,764 KB |
実行使用メモリ | 276,596 KB |
最終ジャッジ日時 | 2025-04-16 00:26:23 |
合計ジャッジ時間 | 8,819 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | TLE * 1 -- * 25 |
ソースコード
from collections import Counter MOD = 998244353 ROOT = 3 def ntt(a, invert=False): n = len(a) rev = list(range(n)) for i in range(1, n): rev[i] = rev[i >> 1] >> 1 if i & 1: rev[i] |= n >> 1 if i < rev[i]: a[i], a[rev[i]] = a[rev[i]], a[i] log_n = (n).bit_length() - 1 for s in range(log_n): m = 1 << (s + 1) w_m = pow(ROOT, (MOD - 1) // m, MOD) if invert: w_m = pow(w_m, MOD - 2, MOD) for k in range(0, n, m): w = 1 for j in range(m // 2): t = w * a[k + j + m // 2] % MOD u = a[k + j] a[k + j] = (u + t) % MOD a[k + j + m // 2] = (u - t) % MOD w = w * w_m % MOD if invert: inv_n = pow(n, MOD - 2, MOD) for i in range(n): a[i] = a[i] * inv_n % MOD return a def convolve(a, b): len_a = len(a) len_b = len(b) len_c = len_a + len_b - 1 n = 1 while n < len_c: n <<= 1 fa = a + [0] * (n - len_a) fb = b + [0] * (n - len_b) ntt(fa) ntt(fb) for i in range(n): fa[i] = fa[i] * fb[i] % MOD ntt(fa, invert=True) return fa[:len_c] def main(): S = input().strip() cnt = Counter(S) max_n = len(S) # Precompute factorial and inverse factorial fact = [1] * (max_n + 1) for i in range(1, max_n + 1): fact[i] = fact[i-1] * i % MOD inv_fact = [1] * (max_n + 1) inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD) for i in range(max_n -1, -1, -1): inv_fact[i] = inv_fact[i+1] * (i+1) % MOD # Generate polynomials for each character polys = [] for k in cnt.values(): if k == 0: continue poly = [inv_fact[t] for t in range(k + 1)] polys.append(poly) if not polys: print(0) return # Convolve all polynomials current = [1] for poly in polys: current = convolve(current, poly) # Calculate the result result = 0 for n in range(1, len(current)): if n > max_n: break term = current[n] * fact[n] % MOD result = (result + term) % MOD print(result) if __name__ == "__main__": main()