結果
問題 |
No.1195 数え上げを愛したい(文字列編)
|
ユーザー |
![]() |
提出日時 | 2025-06-12 14:02:33 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 2,346 bytes |
コンパイル時間 | 214 ms |
コンパイル使用メモリ | 82,644 KB |
実行使用メモリ | 278,104 KB |
最終ジャッジ日時 | 2025-06-12 14:03:42 |
合計ジャッジ時間 | 8,880 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | TLE * 1 -- * 25 |
ソースコード
MOD = 998244353 ROOT = 3 def ntt(a, inverse=False): n = len(a) logn = (n).bit_length() - 1 rev = [0] * n for i in range(n): rev[i] = rev[i >> 1] >> 1 if i & 1: rev[i] |= n >> 1 if i < rev[i]: a[i], a[rev[i]] = a[rev[i]], a[i] for s in range(1, logn + 1): m = 1 << s wm = pow(ROOT, (MOD - 1) // m, MOD) if inverse: wm = pow(wm, MOD - 2, MOD) for k in range(0, n, m): w = 1 for j in range(m >> 1): t = a[k + j + (m >> 1)] * w % MOD u = a[k + j] a[k + j] = (u + t) % MOD a[k + j + (m >> 1)] = (u - t) % MOD w = w * wm % MOD if inverse: inv_n = pow(n, MOD - 2, MOD) for i in range(n): a[i] = a[i] * inv_n % MOD return a def multiply(a, b): len_a = len(a) len_b = len(b) if len_a == 0 or len_b == 0: return [] n = 1 while n < len_a + len_b - 1: n <<= 1 a_ntt = a + [0] * (n - len_a) b_ntt = b + [0] * (n - len_b) a_ntt = ntt(a_ntt) b_ntt = ntt(b_ntt) c_ntt = [(x * y) % MOD for x, y in zip(a_ntt, b_ntt)] c = ntt(c_ntt, inverse=True) for i in range(len(c)): if c[i] < 0: c[i] += MOD c = c[:len_a + len_b - 1] return c def main(): import sys input = sys.stdin.read S = input().strip() n = len(S) max_fact = n fact = [1] * (max_fact + 1) for i in range(1, max_fact + 1): fact[i] = fact[i-1] * i % MOD inv_fact = [1] * (max_fact + 1) inv_fact[max_fact] = pow(fact[max_fact], MOD-2, MOD) for i in range(max_fact -1, -1, -1): inv_fact[i] = inv_fact[i+1] * (i+1) % MOD count = [0] * 26 for c in S: count[ord(c) - ord('a')] += 1 polys = [] for k in count: if k == 0: continue poly = [inv_fact[m] for m in range(k + 1)] polys.append(poly) res = [1] for poly in polys: temp = multiply(res, poly) if len(temp) > n: temp = temp[:n+1] res = temp ans = 0 for i in range(1, len(res)): if i > n: break ans = (ans + res[i] * fact[i]) % MOD print(ans % MOD) if __name__ == "__main__": main()