結果

問題 No.1195 数え上げを愛したい(文字列編)
ユーザー lam6er
提出日時 2025-04-16 16:24:57
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,293 bytes
コンパイル時間 370 ms
コンパイル使用メモリ 82,392 KB
実行使用メモリ 276,480 KB
最終ジャッジ日時 2025-04-16 16:26:44
合計ジャッジ時間 8,972 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 17 TLE * 9
権限があれば一括ダウンロードができます

ソースコード

diff #

from collections import Counter

MOD = 998244353
ROOT = 3

def ntt(a, invert=False):
    n = len(a)
    rev = list(range(n))
    for i in range(1, n):
        rev[i] = rev[i >> 1] >> 1
        if i & 1:
            rev[i] |= n >> 1
        if i < rev[i]:
            a[i], a[rev[i]] = a[rev[i]], a[i]
    log_n = (n).bit_length() - 1
    for s in range(log_n):
        m = 1 << (s + 1)
        w_m = pow(ROOT, (MOD - 1) // m, MOD)
        if invert:
            w_m = pow(w_m, MOD - 2, MOD)
        for k in range(0, n, m):
            w = 1
            for j in range(m // 2):
                t = w * a[k + j + m // 2] % MOD
                u = a[k + j]
                a[k + j] = (u + t) % MOD
                a[k + j + m // 2] = (u - t) % MOD
                w = w * w_m % MOD
    if invert:
        inv_n = pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * inv_n % MOD
    return a

def convolve(a, b):
    len_a = len(a)
    len_b = len(b)
    len_c = len_a + len_b - 1
    n = 1
    while n < len_c:
        n <<= 1
    fa = a + [0] * (n - len_a)
    fb = b + [0] * (n - len_b)
    ntt(fa)
    ntt(fb)
    for i in range(n):
        fa[i] = fa[i] * fb[i] % MOD
    ntt(fa, invert=True)
    return fa[:len_c]

def main():
    S = input().strip()
    cnt = Counter(S)
    max_n = len(S)
    
    # Precompute factorial and inverse factorial
    fact = [1] * (max_n + 1)
    for i in range(1, max_n + 1):
        fact[i] = fact[i-1] * i % MOD
    
    inv_fact = [1] * (max_n + 1)
    inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD)
    for i in range(max_n -1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
    
    # Generate polynomials for each character
    polys = []
    for k in cnt.values():
        if k == 0:
            continue
        poly = [inv_fact[t] for t in range(k + 1)]
        polys.append(poly)
    
    if not polys:
        print(0)
        return
    
    # Convolve all polynomials
    current = [1]
    for poly in polys:
        current = convolve(current, poly)
    
    # Calculate the result
    result = 0
    for n in range(1, len(current)):
        if n > max_n:
            break
        term = current[n] * fact[n] % MOD
        result = (result + term) % MOD
    print(result)

if __name__ == "__main__":
    main()
0