結果
問題 |
No.1195 数え上げを愛したい(文字列編)
|
ユーザー |
![]() |
提出日時 | 2025-03-31 17:56:16 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 1,746 ms / 3,000 ms |
コード長 | 3,220 bytes |
コンパイル時間 | 361 ms |
コンパイル使用メモリ | 82,336 KB |
実行使用メモリ | 261,316 KB |
最終ジャッジ日時 | 2025-03-31 17:57:59 |
合計ジャッジ時間 | 27,157 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 26 |
ソースコード
MOD = 998244353 def main(): import sys from sys import stdin S = stdin.read().strip() cnt = [0] * 26 for c in S: cnt[ord(c) - ord('a')] += 1 max_n = sum(cnt) if max_n == 0: print(0) return # Precompute factorial and inverse factorial fact = [1] * (max_n + 1) for i in range(1, max_n + 1): fact[i] = fact[i-1] * i % MOD inv_fact = [1] * (max_n + 1) inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD) for i in range(max_n-1, -1, -1): inv_fact[i] = inv_fact[i+1] * (i+1) % MOD # Function for NTT-based polynomial multiplication def ntt(a, invert=False): n = len(a) rev = [0] * n for i in range(n): rev[i] = rev[i >> 1] >> 1 if i & 1: rev[i] |= n >> 1 if i < rev[i]: a[i], a[rev[i]] = a[rev[i]], a[i] root = pow(3, (MOD-1)//n, MOD) if not invert else pow(3, MOD-1 - (MOD-1)//n, MOD) roots = [1] * (n//2) for i in range(1, len(roots)): roots[i] = roots[i-1] * root % MOD current_length = 2 while current_length <= n: half = current_length >> 1 step = n // current_length for i in range(0, n, current_length): jk = i for j in range(half): j1 = jk + j j2 = j1 + half u = a[j1] v = a[j2] * roots[j * step] % MOD a[j1] = (u + v) % MOD a[j2] = (u - v) % MOD current_length <<= 1 if invert: inv_n = pow(n, MOD-2, MOD) for i in range(n): a[i] = a[i] * inv_n % MOD return a def multiply(a, b): len_a = len(a) len_b = len(b) n = 1 while n < len_a + len_b -1: n <<=1 a_padded = a + [0]*(n - len_a) b_padded = b + [0]*(n - len_b) a_padded = ntt(a_padded) b_padded = ntt(b_padded) c_padded = [(x*y)%MOD for x, y in zip(a_padded, b_padded)] c = ntt(c_padded, invert=True) c = [x % MOD for x in c] return c[:len_a + len_b -1] # Collect the polynomials for each character polys = [] for c in range(26): m = cnt[c] if m ==0: continue # Generate the polynomial: sum_{k=0}^m (x^k * inv_fact[k]) poly = [0]*(m+1) for k in range(m+1): poly[k] = inv_fact[k] polys.append(poly) if not polys: print(0) return # Multiply all polynomials using a divide and conquer approach import heapq from collections import deque q = deque() for p in polys: q.append(p) # Reduce until one poly remains while len(q) >1: a = q.popleft() b = q.popleft() c = multiply(a, b) q.append(c) final_poly = q[0] # Compute the answer ans = 0 for i in range(len(final_poly)): if i > max_n: break term = final_poly[i] * fact[i] % MOD ans = (ans + term) % MOD ans = (ans -1) % MOD print(ans) if __name__ == "__main__": main()