結果
| 問題 |
No.1195 数え上げを愛したい(文字列編)
|
| コンテスト | |
| ユーザー |
lam6er
|
| 提出日時 | 2025-04-16 00:24:37 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
TLE
|
| 実行時間 | - |
| コード長 | 2,293 bytes |
| コンパイル時間 | 392 ms |
| コンパイル使用メモリ | 81,764 KB |
| 実行使用メモリ | 276,596 KB |
| 最終ジャッジ日時 | 2025-04-16 00:26:23 |
| 合計ジャッジ時間 | 8,819 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | TLE * 1 -- * 25 |
ソースコード
from collections import Counter
MOD = 998244353
ROOT = 3
def ntt(a, invert=False):
n = len(a)
rev = list(range(n))
for i in range(1, n):
rev[i] = rev[i >> 1] >> 1
if i & 1:
rev[i] |= n >> 1
if i < rev[i]:
a[i], a[rev[i]] = a[rev[i]], a[i]
log_n = (n).bit_length() - 1
for s in range(log_n):
m = 1 << (s + 1)
w_m = pow(ROOT, (MOD - 1) // m, MOD)
if invert:
w_m = pow(w_m, MOD - 2, MOD)
for k in range(0, n, m):
w = 1
for j in range(m // 2):
t = w * a[k + j + m // 2] % MOD
u = a[k + j]
a[k + j] = (u + t) % MOD
a[k + j + m // 2] = (u - t) % MOD
w = w * w_m % MOD
if invert:
inv_n = pow(n, MOD - 2, MOD)
for i in range(n):
a[i] = a[i] * inv_n % MOD
return a
def convolve(a, b):
len_a = len(a)
len_b = len(b)
len_c = len_a + len_b - 1
n = 1
while n < len_c:
n <<= 1
fa = a + [0] * (n - len_a)
fb = b + [0] * (n - len_b)
ntt(fa)
ntt(fb)
for i in range(n):
fa[i] = fa[i] * fb[i] % MOD
ntt(fa, invert=True)
return fa[:len_c]
def main():
S = input().strip()
cnt = Counter(S)
max_n = len(S)
# Precompute factorial and inverse factorial
fact = [1] * (max_n + 1)
for i in range(1, max_n + 1):
fact[i] = fact[i-1] * i % MOD
inv_fact = [1] * (max_n + 1)
inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD)
for i in range(max_n -1, -1, -1):
inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
# Generate polynomials for each character
polys = []
for k in cnt.values():
if k == 0:
continue
poly = [inv_fact[t] for t in range(k + 1)]
polys.append(poly)
if not polys:
print(0)
return
# Convolve all polynomials
current = [1]
for poly in polys:
current = convolve(current, poly)
# Calculate the result
result = 0
for n in range(1, len(current)):
if n > max_n:
break
term = current[n] * fact[n] % MOD
result = (result + term) % MOD
print(result)
if __name__ == "__main__":
main()
lam6er