結果
問題 |
No.1195 数え上げを愛したい(文字列編)
|
ユーザー |
![]() |
提出日時 | 2025-04-16 16:24:57 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 2,107 bytes |
コンパイル時間 | 411 ms |
コンパイル使用メモリ | 81,872 KB |
実行使用メモリ | 275,588 KB |
最終ジャッジ日時 | 2025-04-16 16:26:42 |
合計ジャッジ時間 | 54,353 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 17 TLE * 9 |
ソースコード
MOD = 998244353 ROOT = 3 def ntt(a, inverse=False): n = len(a) logn = (n - 1).bit_length() rev = [0] * n for i in range(n): rev[i] = rev[i >> 1] >> 1 if i & 1: rev[i] |= n >> 1 if i < rev[i]: a[i], a[rev[i]] = a[rev[i]], a[i] for m in range(1, logn + 1): m_h = 1 << (m - 1) w_m = pow(ROOT, (MOD - 1) // (1 << m), MOD) if inverse: w_m = pow(w_m, MOD - 2, MOD) for i in range(0, n, 1 << m): w = 1 for j in range(i, i + m_h): x = a[j] y = a[j + m_h] * w % MOD a[j] = (x + y) % MOD a[j + m_h] = (x - y) % MOD w = w * w_m % MOD if inverse: inv_n = pow(n, MOD - 2, MOD) for i in range(n): a[i] = a[i] * inv_n % MOD return a def convolve(a, b): len_a = len(a) len_b = len(b) len_c = len_a + len_b - 1 n = 1 << (len_c.bit_length()) if len_c != 0 else 1 a = a + [0] * (n - len_a) b = b + [0] * (n - len_b) a = ntt(a) b = ntt(b) c = [a[i] * b[i] % MOD for i in range(n)] c = ntt(c, inverse=True) return c[:len_c] def main(): import sys from collections import Counter s = sys.stdin.readline().strip() cnt = Counter(s) chars = [v for v in cnt.values() if v > 0] max_n = 3 * 10**5 # Precompute factorial and inverse factorial fact = [1] * (max_n + 1) for i in range(1, max_n + 1): fact[i] = fact[i-1] * i % MOD inv_fact = [1] * (max_n + 1) inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD) for i in range(max_n - 1, -1, -1): inv_fact[i] = inv_fact[i+1] * (i+1) % MOD dp = [1] for a in chars: g = [inv_fact[k] for k in range(a + 1)] new_dp = convolve(dp, g) new_dp = new_dp[:max_n + 1] dp = new_dp ans = 0 for n in range(1, len(dp)): if n > max_n: break term = fact[n] * dp[n] % MOD ans = (ans + term) % MOD print(ans) if __name__ == "__main__": main()