結果
問題 |
No.1195 数え上げを愛したい(文字列編)
|
ユーザー |
![]() |
提出日時 | 2025-04-16 16:24:59 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 2,004 bytes |
コンパイル時間 | 465 ms |
コンパイル使用メモリ | 82,132 KB |
実行使用メモリ | 231,952 KB |
最終ジャッジ日時 | 2025-04-16 16:26:21 |
合計ジャッジ時間 | 8,851 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | TLE * 1 -- * 25 |
ソースコード
MOD = 998244353 ROOT = 3 def ntt(a, invert=False): n = len(a) j = 0 for i in range(1, n): bit = n >> 1 while j >= bit: j -= bit bit >>= 1 j += bit if i < j: a[i], a[j] = a[j], a[i] log_n = (n).bit_length() - 1 for s in range(log_n): m = 1 << (s + 1) mh = m >> 1 w = pow(ROOT, (MOD - 1) // m, MOD) if invert: w = pow(w, MOD-2, MOD) for i in range(0, n, m): wk = 1 for j in range(i, i + mh): x = a[j] y = a[j + mh] * wk % MOD a[j] = (x + y) % MOD a[j + mh] = (x - y) % MOD wk = wk * w % MOD if invert: inv_n = pow(n, MOD-2, MOD) for i in range(n): a[i] = a[i] * inv_n % MOD return a def convolve(a, b): len_a = len(a) len_b = len(b) n = 1 while n < len_a + len_b - 1: n <<= 1 a += [0] * (n - len_a) b += [0] * (n - len_b) a = ntt(a) b = ntt(b) c = [(x * y) % MOD for x, y in zip(a, b)] c = ntt(c, invert=True) del c[len_a + len_b - 1:] return c def main(): import sys input = sys.stdin.read S = input().strip() max_n = len(S) fact = [1] * (max_n + 1) for i in range(1, max_n + 1): fact[i] = fact[i-1] * i % MOD inv_fact = [1] * (max_n + 1) inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD) for i in range(max_n -1, -1, -1): inv_fact[i] = inv_fact[i+1] * (i+1) % MOD cnt = [0] * 26 for c in S: cnt[ord(c) - ord('a')] += 1 a_old = [1] for i in range(26): m = cnt[i] if m == 0: continue b = [inv_fact[k] for k in range(m + 1)] a_new = convolve(a_old, b) a_old = a_new ans = 0 for j in range(1, len(a_old)): ans = (ans + a_old[j] * fact[j]) % MOD print(ans) if __name__ == "__main__": main()