結果
問題 | No.1195 数え上げを愛したい(文字列編) |
ユーザー |
![]() |
提出日時 | 2025-03-20 20:28:13 |
言語 | PyPy3 (7.3.15) |
結果 | TLE |
実行時間 | - |
コード長 | 3,245 bytes |
コンパイル時間 | 238 ms |
コンパイル使用メモリ | 82,728 KB |
実行使用メモリ | 216,848 KB |
最終ジャッジ日時 | 2025-03-20 20:29:36 |
合計ジャッジ時間 | 9,274 ms |
ジャッジサーバーID (参考情報) | judge1 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | TLE * 1 -- * 25 |
ソースコード
# Count the distinct non-empty strings that can be formed from a sub-multiset
# of the letters of S, modulo 998244353 (yukicoder No.1195).
import heapq
import sys
from collections import Counter

MOD = 998244353


def main():
    """Read S from stdin and print the answer modulo MOD.

    A letter occurring c times contributes the exponential generating
    function G_c(x) = sum_{t=0}^{c} x^t / t!.  In the product P(x) of all
    G_c, k! * [x^k] P(x) is the number of distinct length-k strings, so the
    answer is sum_{k>=1} k! * [x^k] P(x).
    """
    s = sys.stdin.readline().strip()
    n = len(s)

    # Factorials and inverse factorials up to n.
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], MOD - 2, MOD)
    for i in range(n, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD

    # Only the multiset of letter counts matters; Counter handles any
    # alphabet without relying on ord(c) - ord('a') being in range.
    heap = [
        (c + 1, idx, inv_fact[:c + 1])
        for idx, c in enumerate(Counter(s).values())
    ]
    heapq.heapify(heap)

    # Multiply the two shortest polynomials first (Huffman order).  Folding
    # everything into one growing accumulator makes nearly every product run
    # at size ~n; this order keeps each coefficient in ~log2(26) products.
    next_id = len(heap)  # unique tiebreak so lists are never compared
    while len(heap) > 1:
        _, _, p = heapq.heappop(heap)
        _, _, q = heapq.heappop(heap)
        r = convolve(p, q)
        heapq.heappush(heap, (len(r), next_id, r))
        next_id += 1
    poly = heap[0][2] if heap else [1]

    ans = 0
    for k in range(1, len(poly)):
        ans = (ans + poly[k] * fact[k]) % MOD
    print(ans)


def primitive_root(m):
    """Return the smallest primitive root modulo the prime m."""
    known = {2: 1, 167772161: 3, 469762049: 3, 754974721: 11, 998244353: 3}
    if m in known:
        return known[m]
    # Collect the distinct prime factors of m - 1 (2 is always one of them).
    divs = [2]
    x = (m - 1) // 2
    while x % 2 == 0:
        x //= 2
    d = 3
    while d * d <= x:
        if x % d == 0:
            divs.append(d)
            while x % d == 0:
                x //= d
        d += 2
    if x > 1:
        divs.append(x)
    # g is a primitive root iff g^((m-1)/p) != 1 for every prime p | m-1.
    g = 2
    while True:
        if all(pow(g, (m - 1) // p, m) != 1 for p in divs):
            return g
        g += 1


def ntt(a, invert=False):
    """Apply an in-place iterative radix-2 NTT modulo MOD and return a.

    len(a) must be a power of two.  With invert=True the inverse transform
    is computed, including the final multiplication by 1/len(a).  All
    output values lie in [0, MOD).
    """
    n = len(a)

    # Bit-reversal permutation.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    g = primitive_root(MOD)
    length = 2
    while length <= n:
        half = length >> 1
        # Primitive length-th root of unity (or its inverse).
        wlen = pow(g, (MOD - 1) // length, MOD)
        if invert:
            wlen = pow(wlen, MOD - 2, MOD)
        # Twiddles for this stage, computed once instead of per butterfly.
        tw = [1] * half
        for k in range(1, half):
            tw[k] = tw[k - 1] * wlen % MOD
        for start in range(0, n, length):
            for k in range(half):
                p = start + k
                q = p + half
                u = a[p]
                v = a[q] * tw[k] % MOD
                a[p] = (u + v) % MOD
                a[q] = (u - v) % MOD
        length <<= 1

    if invert:
        inv_n = pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * inv_n % MOD
    return a


def convolve(a, b):
    """Return the coefficients of the product of polynomials a and b mod MOD.

    The inputs are not modified.  Returns [] when either input is empty.
    """
    len_a, len_b = len(a), len(b)
    if not len_a or not len_b:
        return []
    out_len = len_a + len_b - 1

    # With a short operand the schoolbook product is cheaper than 3 NTTs.
    if min(len_a, len_b) <= 32:
        if len_a < len_b:
            a, b = b, a
        res = [0] * out_len
        for j, bj in enumerate(b):
            if bj:
                for i, ai in enumerate(a):
                    res[i + j] = (res[i + j] + ai * bj) % MOD
        return res

    size = 1
    while size < out_len:
        size <<= 1
    # Copy into fresh zero-padded buffers so the callers' lists stay intact.
    fa = ntt(list(a) + [0] * (size - len_a))
    fb = ntt(list(b) + [0] * (size - len_b))
    fc = ntt([x * y % MOD for x, y in zip(fa, fb)], invert=True)
    return fc[:out_len]


if __name__ == '__main__':
    main()