結果
| 問題 | No.1195 数え上げを愛したい(文字列編) |
|---|---|
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-03-20 20:28:13 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 3,245 bytes |
| コンパイル時間 | 238 ms |
| コンパイル使用メモリ | 82,728 KB |
| 実行使用メモリ | 216,848 KB |
| 最終ジャッジ日時 | 2025-03-20 20:29:36 |
| 合計ジャッジ時間 | 9,274 ms |
| ジャッジサーバーID (参考情報) | judge1 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | TLE * 1 -- * 25 |
ソースコード
import heapq
import sys
from collections import Counter
# NTT-friendly prime (== 998244353): MOD - 1 = 119 * 2**23, so primitive
# 2**k-th roots of unity exist for every power-of-two transform up to 2**23.
MOD = 119 * 2**23 + 1
def main():
    """Read S from stdin and print, modulo MOD, the number of distinct
    non-empty strings that can be built from a sub-multiset of S's characters.

    Each distinct character occurring c times contributes the exponential
    generating function G_c(x) = sum_{t=0}^{c} x^t / t!.  In the product of
    all G_c, k! * [x^k] equals the number of distinct length-k strings, so
    the answer is the sum of those values for k >= 1.
    """
    s = sys.stdin.readline().strip()
    n = len(s)

    # Factorials and inverse factorials modulo MOD for 0..n.
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], MOD - 2, MOD)
    for i in range(n, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD

    # One truncated EGF per distinct character.  Always multiplying the two
    # shortest polynomials keeps the product tree balanced; a left-to-right
    # running product redoes a near-full-size NTT for every character.
    # The index in each tuple breaks length ties so lists are never compared.
    heap = [(c + 1, idx, inv_fact[:c + 1])
            for idx, c in enumerate(Counter(s).values())]
    heapq.heapify(heap)
    next_id = len(heap)
    while len(heap) > 1:
        _, _, p = heapq.heappop(heap)
        _, _, q = heapq.heappop(heap)
        r = convolve(p, q)
        heapq.heappush(heap, (len(r), next_id, r))
        next_id += 1
    # Empty input: the empty product is the constant polynomial 1.
    product = heap[0][2] if heap else [1]

    # k! * [x^k] counts distinct length-k strings; k = 0 (empty string) is skipped.
    ans = 0
    for k in range(1, len(product)):
        ans = (ans + product[k] * fact[k]) % MOD
    print(ans)
def primitive_root(m):
    """Return a primitive root modulo the prime ``m``.

    Common NTT primes are answered from a table.  Otherwise, candidates
    g = 2, 3, ... are tried in order, and the first g with
    g**((m-1)/p) != 1 (mod m) for every distinct prime factor p of m - 1
    is returned.  That test only characterizes primitive roots when m is
    prime.

    Raises:
        ValueError: if m < 2 (the search would never terminate).
    """
    known = {2: 1, 167772161: 3, 469762049: 3, 754974721: 11, 998244353: 3}
    if m in known:
        return known[m]
    if m < 2:
        raise ValueError('modulus must be a prime >= 2')

    # Distinct prime factors of m - 1, each listed exactly once.
    factors = [2]
    x = (m - 1) // 2
    while x % 2 == 0:
        x //= 2
    d = 3
    while d * d <= x:
        if x % d == 0:
            factors.append(d)
            while x % d == 0:
                x //= d
        d += 2
    if x > 1:
        factors.append(x)

    g = 2
    while True:
        if all(pow(g, (m - 1) // p, m) != 1 for p in factors):
            return g
        g += 1
def ntt(a, invert=False):
    """Apply an in-place iterative radix-2 number-theoretic transform mod MOD.

    Args:
        a: list of ints whose length is a power of two; overwritten in place.
        invert: when True, apply the inverse transform, including the final
            multiplication by len(a)**-1 mod MOD.

    Returns:
        ``a`` itself (the same list object), so calls can be chained.

    Raises:
        ValueError: if len(a) is zero or not a power of two.
    """
    mod = MOD
    n = len(a)
    if n == 0 or n & (n - 1):
        raise ValueError('ntt length must be a power of two')

    # Bit-reversal permutation so the butterflies below can run in place.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j |= bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    # roots[s] is a primitive 2**s-th root of unity (the inverse root when
    # inverting), derived by repeated squaring from the n-th root.
    log_n = n.bit_length() - 1
    root = primitive_root(mod)
    step = (mod - 1) // n
    if invert:
        step = (mod - 1) - step
    roots = [1] * (log_n + 1)
    roots[log_n] = pow(root, step, mod)
    for s in range(log_n - 1, -1, -1):
        roots[s] = roots[s + 1] * roots[s + 1] % mod

    half = 1
    for s in range(1, log_n + 1):
        # Twiddle factors for this stage are computed once, not once per block.
        w_step = roots[s]
        tw = [1] * half
        for k in range(1, half):
            tw[k] = tw[k - 1] * w_step % mod
        for start in range(0, n, half << 1):
            for k in range(half):
                p = start + k
                q = p + half
                v = a[q] * tw[k] % mod
                u = a[p]
                a[p] = (u + v) % mod
                a[q] = (u - v) % mod
        half <<= 1

    if invert:
        inv_n = pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * inv_n % mod
    return a
def convolve(a, b):
    """Return the coefficients of the polynomial product a * b modulo MOD.

    Args:
        a, b: coefficient lists (index i holds the x**i coefficient).
            They are NOT modified.

    Returns:
        A new list of len(a) + len(b) - 1 coefficients in [0, MOD), or []
        if either input is empty.
    """
    len_a = len(a)
    len_b = len(b)
    if len_a == 0 or len_b == 0:
        return []
    out_len = len_a + len_b - 1

    # Schoolbook multiplication beats three NTTs when one operand is short.
    if min(len_a, len_b) <= 32:
        long_, short = (a, b) if len_a >= len_b else (b, a)
        res = [0] * out_len
        for j, y in enumerate(short):
            if y % MOD == 0:
                continue
            for i, x in enumerate(long_):
                res[i + j] = (res[i + j] + x * y) % MOD
        return res

    size = 1
    while size < out_len:
        size <<= 1
    # Work on padded copies: ntt() overwrites its argument in place, and the
    # caller's lists must not be padded or transformed.
    fa = ntt(list(a) + [0] * (size - len_a))
    fb = ntt(list(b) + [0] * (size - len_b))
    prod = ntt([x * y % MOD for x, y in zip(fa, fb)], invert=True)
    # The inverse transform already reduces every value modulo MOD.
    return prod[:out_len]
if __name__ == '__main__':
    # Solve only when executed as a script, so the helpers stay importable.
    main()
lam6er