結果

問題 No.1195 数え上げを愛したい(文字列編)
ユーザー lam6er
提出日時 2025-04-16 16:24:59
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,004 bytes
コンパイル時間 465 ms
コンパイル使用メモリ 82,132 KB
実行使用メモリ 231,952 KB
最終ジャッジ日時 2025-04-16 16:26:21
合計ジャッジ時間 8,851 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other TLE * 1 -- * 25
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353
ROOT = 3

def ntt(a, invert=False):
    n = len(a)
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    log_n = (n).bit_length() - 1
    for s in range(log_n):
        m = 1 << (s + 1)
        mh = m >> 1
        w = pow(ROOT, (MOD - 1) // m, MOD)
        if invert:
            w = pow(w, MOD-2, MOD)
        for i in range(0, n, m):
            wk = 1
            for j in range(i, i + mh):
                x = a[j]
                y = a[j + mh] * wk % MOD
                a[j] = (x + y) % MOD
                a[j + mh] = (x - y) % MOD
                wk = wk * w % MOD
    if invert:
        inv_n = pow(n, MOD-2, MOD)
        for i in range(n):
            a[i] = a[i] * inv_n % MOD
    return a

def convolve(a, b):
    len_a = len(a)
    len_b = len(b)
    n = 1
    while n < len_a + len_b - 1:
        n <<= 1
    a += [0] * (n - len_a)
    b += [0] * (n - len_b)
    a = ntt(a)
    b = ntt(b)
    c = [(x * y) % MOD for x, y in zip(a, b)]
    c = ntt(c, invert=True)
    del c[len_a + len_b - 1:]
    return c

def main():
    import sys
    input = sys.stdin.read
    S = input().strip()
    
    max_n = len(S)
    fact = [1] * (max_n + 1)
    for i in range(1, max_n + 1):
        fact[i] = fact[i-1] * i % MOD
    
    inv_fact = [1] * (max_n + 1)
    inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD)
    for i in range(max_n -1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
    
    cnt = [0] * 26
    for c in S:
        cnt[ord(c) - ord('a')] += 1
    
    a_old = [1]
    
    for i in range(26):
        m = cnt[i]
        if m == 0:
            continue
        b = [inv_fact[k] for k in range(m + 1)]
        a_new = convolve(a_old, b)
        a_old = a_new
    
    ans = 0
    for j in range(1, len(a_old)):
        ans = (ans + a_old[j] * fact[j]) % MOD
    print(ans)

if __name__ == "__main__":
    main()
0