結果

問題 No.1195 数え上げを愛したい(文字列編)
ユーザー gew1fw
提出日時 2025-06-12 15:54:18
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,094 bytes
コンパイル時間 319 ms
コンパイル使用メモリ 82,944 KB
実行使用メモリ 243,068 KB
最終ジャッジ日時 2025-06-12 15:54:27
合計ジャッジ時間 8,705 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other TLE * 1 -- * 25
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353
MAX_N = 3 * 10**5 + 10

# Precompute factorials and inverse factorials modulo MOD
fact = [1] * (MAX_N)
for i in range(1, MAX_N):
    fact[i] = fact[i-1] * i % MOD

inv_fact = [1] * (MAX_N)
inv_fact[MAX_N-1] = pow(fact[MAX_N-1], MOD-2, MOD)
for i in range(MAX_N-2, -1, -1):
    inv_fact[i] = inv_fact[i+1] * (i+1) % MOD

def ntt(a, invert=False):
    n = len(a)
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    log_n = (n-1).bit_length()
    for s in range(1, log_n+1):
        m = 1 << s
        wm = pow(3, (MOD-1)//m, MOD) if not invert else pow(3, MOD-1 - (MOD-1)//m, MOD)
        for k in range(0, n, m):
            w = 1
            for j in range(m//2):
                t = a[k+j+m//2] * w % MOD
                u = a[k+j]
                a[k+j] = (u + t) % MOD
                a[k+j+m//2] = (u - t) % MOD
                w = w * wm % MOD
    if invert:
        inv_n = pow(n, MOD-2, MOD)
        for i in range(n):
            a[i] = a[i] * inv_n % MOD
    return a

def multiply(a, b):
    size = 1
    while size < len(a) + len(b) - 1:
        size <<= 1
    a += [0] * (size - len(a))
    b += [0] * (size - len(b))
    a = ntt(a)
    b = ntt(b)
    c = [a[i] * b[i] % MOD for i in range(size)]
    c = ntt(c, invert=True)
    return c[:len(a)+len(b)-1]

def main():
    import sys
    S = sys.stdin.read().strip()
    from collections import Counter
    cnt = Counter(S)
    freq = [0] * 26
    for c, v in cnt.items():
        idx = ord(c) - ord('a')
        freq[idx] = v

    F = [1]
    for c in range(26):
        f_c = freq[c]
        if f_c == 0:
            continue
        P = [0] * (f_c + 1)
        for k in range(f_c + 1):
            P[k] = inv_fact[k]
        F = multiply(F, P)
        F = F[:len(F)]
    N = sum(freq)
    res = 0
    for n in range(1, len(F)):
        if n > N:
            break
        res = (res + F[n] * fact[n]) % MOD
    print(res % MOD)

if __name__ == "__main__":
    main()
0