結果

問題 No.1195 数え上げを愛したい(文字列編)
ユーザー gew1fw
提出日時 2025-06-12 14:02:33
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,346 bytes
コンパイル時間 214 ms
コンパイル使用メモリ 82,644 KB
実行使用メモリ 278,104 KB
最終ジャッジ日時 2025-06-12 14:03:42
合計ジャッジ時間 8,880 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other TLE * 1 -- * 25
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353
ROOT = 3

def ntt(a, inverse=False):
    n = len(a)
    logn = (n).bit_length() - 1
    rev = [0] * n
    for i in range(n):
        rev[i] = rev[i >> 1] >> 1
        if i & 1:
            rev[i] |= n >> 1
        if i < rev[i]:
            a[i], a[rev[i]] = a[rev[i]], a[i]
    
    for s in range(1, logn + 1):
        m = 1 << s
        wm = pow(ROOT, (MOD - 1) // m, MOD)
        if inverse:
            wm = pow(wm, MOD - 2, MOD)
        for k in range(0, n, m):
            w = 1
            for j in range(m >> 1):
                t = a[k + j + (m >> 1)] * w % MOD
                u = a[k + j]
                a[k + j] = (u + t) % MOD
                a[k + j + (m >> 1)] = (u - t) % MOD
                w = w * wm % MOD
    if inverse:
        inv_n = pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * inv_n % MOD
    return a

def multiply(a, b):
    len_a = len(a)
    len_b = len(b)
    if len_a == 0 or len_b == 0:
        return []
    n = 1
    while n < len_a + len_b - 1:
        n <<= 1
    a_ntt = a + [0] * (n - len_a)
    b_ntt = b + [0] * (n - len_b)
    a_ntt = ntt(a_ntt)
    b_ntt = ntt(b_ntt)
    c_ntt = [(x * y) % MOD for x, y in zip(a_ntt, b_ntt)]
    c = ntt(c_ntt, inverse=True)
    for i in range(len(c)):
        if c[i] < 0:
            c[i] += MOD
    c = c[:len_a + len_b - 1]
    return c

def main():
    import sys
    input = sys.stdin.read
    S = input().strip()
    n = len(S)
    max_fact = n
    fact = [1] * (max_fact + 1)
    for i in range(1, max_fact + 1):
        fact[i] = fact[i-1] * i % MOD
    inv_fact = [1] * (max_fact + 1)
    inv_fact[max_fact] = pow(fact[max_fact], MOD-2, MOD)
    for i in range(max_fact -1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
    
    count = [0] * 26
    for c in S:
        count[ord(c) - ord('a')] += 1
    
    polys = []
    for k in count:
        if k == 0:
            continue
        poly = [inv_fact[m] for m in range(k + 1)]
        polys.append(poly)
    
    res = [1]
    for poly in polys:
        temp = multiply(res, poly)
        if len(temp) > n:
            temp = temp[:n+1]
        res = temp
    
    ans = 0
    for i in range(1, len(res)):
        if i > n:
            break
        ans = (ans + res[i] * fact[i]) % MOD
    print(ans % MOD)

if __name__ == "__main__":
    main()
0