結果
| 問題 | No.1195 数え上げを愛したい(文字列編) |
|---|---|
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-03-31 17:56:16 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 1,746 ms / 3,000 ms |
| コード長 | 3,220 bytes |
| コンパイル時間 | 361 ms |
| コンパイル使用メモリ | 82,336 KB |
| 実行使用メモリ | 261,316 KB |
| 最終ジャッジ日時 | 2025-03-31 17:57:59 |
| 合計ジャッジ時間 | 27,157 ms |
| ジャッジサーバーID (参考情報) | judge4 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 26 |
ソースコード
MOD = 998244353
def main():
import sys
from sys import stdin
S = stdin.read().strip()
cnt = [0] * 26
for c in S:
cnt[ord(c) - ord('a')] += 1
max_n = sum(cnt)
if max_n == 0:
print(0)
return
# Precompute factorial and inverse factorial
fact = [1] * (max_n + 1)
for i in range(1, max_n + 1):
fact[i] = fact[i-1] * i % MOD
inv_fact = [1] * (max_n + 1)
inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD)
for i in range(max_n-1, -1, -1):
inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
# Function for NTT-based polynomial multiplication
def ntt(a, invert=False):
n = len(a)
rev = [0] * n
for i in range(n):
rev[i] = rev[i >> 1] >> 1
if i & 1:
rev[i] |= n >> 1
if i < rev[i]:
a[i], a[rev[i]] = a[rev[i]], a[i]
root = pow(3, (MOD-1)//n, MOD) if not invert else pow(3, MOD-1 - (MOD-1)//n, MOD)
roots = [1] * (n//2)
for i in range(1, len(roots)):
roots[i] = roots[i-1] * root % MOD
current_length = 2
while current_length <= n:
half = current_length >> 1
step = n // current_length
for i in range(0, n, current_length):
jk = i
for j in range(half):
j1 = jk + j
j2 = j1 + half
u = a[j1]
v = a[j2] * roots[j * step] % MOD
a[j1] = (u + v) % MOD
a[j2] = (u - v) % MOD
current_length <<= 1
if invert:
inv_n = pow(n, MOD-2, MOD)
for i in range(n):
a[i] = a[i] * inv_n % MOD
return a
def multiply(a, b):
len_a = len(a)
len_b = len(b)
n = 1
while n < len_a + len_b -1:
n <<=1
a_padded = a + [0]*(n - len_a)
b_padded = b + [0]*(n - len_b)
a_padded = ntt(a_padded)
b_padded = ntt(b_padded)
c_padded = [(x*y)%MOD for x, y in zip(a_padded, b_padded)]
c = ntt(c_padded, invert=True)
c = [x % MOD for x in c]
return c[:len_a + len_b -1]
# Collect the polynomials for each character
polys = []
for c in range(26):
m = cnt[c]
if m ==0:
continue
# Generate the polynomial: sum_{k=0}^m (x^k * inv_fact[k])
poly = [0]*(m+1)
for k in range(m+1):
poly[k] = inv_fact[k]
polys.append(poly)
if not polys:
print(0)
return
# Multiply all polynomials using a divide and conquer approach
import heapq
from collections import deque
q = deque()
for p in polys:
q.append(p)
# Reduce until one poly remains
while len(q) >1:
a = q.popleft()
b = q.popleft()
c = multiply(a, b)
q.append(c)
final_poly = q[0]
# Compute the answer
ans = 0
for i in range(len(final_poly)):
if i > max_n:
break
term = final_poly[i] * fact[i] % MOD
ans = (ans + term) % MOD
ans = (ans -1) % MOD
print(ans)
if __name__ == "__main__":
main()
lam6er