結果
| 問題 |
No.1195 数え上げを愛したい(文字列編)
|
| コンテスト | |
| ユーザー |
gew1fw
|
| 提出日時 | 2025-06-12 14:02:33 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
TLE
|
| 実行時間 | - |
| コード長 | 2,346 bytes |
| コンパイル時間 | 214 ms |
| コンパイル使用メモリ | 82,644 KB |
| 実行使用メモリ | 278,104 KB |
| 最終ジャッジ日時 | 2025-06-12 14:03:42 |
| 合計ジャッジ時間 | 8,880 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | TLE * 1 -- * 25 |
ソースコード
MOD = 998244353
ROOT = 3
def ntt(a, inverse=False):
n = len(a)
logn = (n).bit_length() - 1
rev = [0] * n
for i in range(n):
rev[i] = rev[i >> 1] >> 1
if i & 1:
rev[i] |= n >> 1
if i < rev[i]:
a[i], a[rev[i]] = a[rev[i]], a[i]
for s in range(1, logn + 1):
m = 1 << s
wm = pow(ROOT, (MOD - 1) // m, MOD)
if inverse:
wm = pow(wm, MOD - 2, MOD)
for k in range(0, n, m):
w = 1
for j in range(m >> 1):
t = a[k + j + (m >> 1)] * w % MOD
u = a[k + j]
a[k + j] = (u + t) % MOD
a[k + j + (m >> 1)] = (u - t) % MOD
w = w * wm % MOD
if inverse:
inv_n = pow(n, MOD - 2, MOD)
for i in range(n):
a[i] = a[i] * inv_n % MOD
return a
def multiply(a, b):
len_a = len(a)
len_b = len(b)
if len_a == 0 or len_b == 0:
return []
n = 1
while n < len_a + len_b - 1:
n <<= 1
a_ntt = a + [0] * (n - len_a)
b_ntt = b + [0] * (n - len_b)
a_ntt = ntt(a_ntt)
b_ntt = ntt(b_ntt)
c_ntt = [(x * y) % MOD for x, y in zip(a_ntt, b_ntt)]
c = ntt(c_ntt, inverse=True)
for i in range(len(c)):
if c[i] < 0:
c[i] += MOD
c = c[:len_a + len_b - 1]
return c
def main():
import sys
input = sys.stdin.read
S = input().strip()
n = len(S)
max_fact = n
fact = [1] * (max_fact + 1)
for i in range(1, max_fact + 1):
fact[i] = fact[i-1] * i % MOD
inv_fact = [1] * (max_fact + 1)
inv_fact[max_fact] = pow(fact[max_fact], MOD-2, MOD)
for i in range(max_fact -1, -1, -1):
inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
count = [0] * 26
for c in S:
count[ord(c) - ord('a')] += 1
polys = []
for k in count:
if k == 0:
continue
poly = [inv_fact[m] for m in range(k + 1)]
polys.append(poly)
res = [1]
for poly in polys:
temp = multiply(res, poly)
if len(temp) > n:
temp = temp[:n+1]
res = temp
ans = 0
for i in range(1, len(res)):
if i > n:
break
ans = (ans + res[i] * fact[i]) % MOD
print(ans % MOD)
if __name__ == "__main__":
main()
gew1fw