結果

問題 No.1195 数え上げを愛したい(文字列編)
ユーザー lam6er
提出日時 2025-03-20 20:28:13
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,245 bytes
コンパイル時間 238 ms
コンパイル使用メモリ 82,728 KB
実行使用メモリ 216,848 KB
最終ジャッジ日時 2025-03-20 20:29:36
合計ジャッジ時間 9,274 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other TLE * 1 -- * 25
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
# Modulus for every modular operation in this file: 998244353 = 119 * 2**23 + 1.
MOD = 998244353

def main():
    """Read a string from stdin and print a permutation count modulo MOD.

    For each letter occurring c times in the input, the polynomial
    sum_{t=0..c} x^t / t! is formed.  The product of these polynomials has,
    as its x^k coefficient times k!, the number of distinct strings of length
    k that use each letter at most as often as it occurs in the input.  The
    printed answer is the sum of those counts over k >= 1, modulo MOD.

    The input line is assumed to consist of lowercase letters 'a'-'z' only;
    other characters are not validated.
    """
    import heapq  # local import: only this routine needs a priority queue

    S = sys.stdin.readline().strip()
    n = len(S)

    # Factorials and inverse factorials up to n, all reduced modulo MOD.
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], MOD - 2, MOD)
    for i in range(n - 1, -1, -1):
        inv_fact[i] = inv_fact[i + 1] * (i + 1) % MOD

    # Occurrences of each letter.
    cnt = [0] * 26
    for ch in S:
        cnt[ord(ch) - ord('a')] += 1

    # One polynomial sum_{t<=c} x^t / t! per letter that occurs.  Heap entries
    # are (length, unique id, coefficients); the id breaks ties so the
    # coefficient lists themselves are never compared.
    heap = [(c + 1, idx, inv_fact[:c + 1]) for idx, c in enumerate(cnt) if c]
    heapq.heapify(heap)
    serial = len(cnt)

    # Always multiply the two shortest polynomials first.  The product is the
    # same as multiplying left to right, but far fewer large transforms run
    # than when every factor is folded into one ever-growing accumulator.
    while len(heap) > 1:
        _, _, left = heapq.heappop(heap)
        _, _, right = heapq.heappop(heap)
        merged = convolve(left, right)
        heapq.heappush(heap, (len(merged), serial, merged))
        serial += 1

    # With no letters at all the product is the constant polynomial 1.
    current = heap[0][2] if heap else [1]

    # Coefficient of x^k times k! counts the length-k strings; skip k = 0.
    ans = 0
    for k in range(1, len(current)):
        ans = (ans + current[k] * fact[k]) % MOD
    print(ans)
    
def primitive_root(m):
    """Return a primitive root modulo ``m``.

    ``m`` is expected to be prime; no primality check is performed.  A few
    moduli are answered from a fixed table.  Otherwise the distinct prime
    factors of ``m - 1`` are found by trial division and candidates
    g = 2, 3, ... are tried in order; the first g with
    g^((m-1)/p) != 1 (mod m) for every prime factor p is returned.

    Raises:
        ValueError: if ``m`` is smaller than 2.
    """
    if m < 2:
        raise ValueError("m must be a prime >= 2")

    known_roots = {
        2: 1,
        167772161: 3,
        469762049: 3,
        754974721: 11,
        998244353: 3,
    }
    if m in known_roots:
        return known_roots[m]

    # Distinct prime factors of m - 1.  The factor 2 is recorded once; the
    # remaining powers of two are only stripped, not stored again.
    prime_factors = [2]
    rest = (m - 1) // 2
    while rest % 2 == 0:
        rest //= 2
    p = 3
    while p * p <= rest:
        if rest % p == 0:
            prime_factors.append(p)
            while rest % p == 0:
                rest //= p
        p += 2
    if rest > 1:
        # Whatever remains after trial division is itself a prime factor.
        prime_factors.append(rest)

    g = 2
    while True:
        # g generates the full group iff no proper quotient exponent gives 1.
        if all(pow(g, (m - 1) // q, m) != 1 for q in prime_factors):
            return g
        g += 1

def ntt(a, invert=False):
    """Transform the list ``a`` in place with a number-theoretic transform.

    The transform works modulo MOD using the generator returned by
    ``primitive_root(MOD)``.  With ``invert=True`` the inverse root is used
    and every entry is finally multiplied by the modular inverse of
    ``len(a)``.  The same list object is returned.

    The length of ``a`` is expected to be a power of two; this is not checked.
    """
    g = primitive_root(MOD)
    n = len(a)

    # Permute the entries into bit-reversed order.
    mirrored = 0
    for idx in range(1, n):
        bit = n >> 1
        while mirrored >= bit:
            mirrored -= bit
            bit >>= 1
        mirrored += bit
        if idx < mirrored:
            a[idx], a[mirrored] = a[mirrored], a[idx]

    # stage_roots[s] is the top-level root squared (levels - s) times, i.e.
    # the step used by butterflies whose blocks have size 2**s.
    levels = n.bit_length() - 1
    exponent = (MOD - 1) // n
    if invert:
        exponent = (MOD - 1) - exponent
    stage_roots = [pow(g, exponent, MOD)]
    for _ in range(levels):
        stage_roots.append(stage_roots[-1] * stage_roots[-1] % MOD)
    stage_roots.reverse()

    # Butterfly passes, doubling the block size each level.
    for level in range(levels):
        half = 1 << level
        step = stage_roots[level + 1]
        for start in range(0, n, half << 1):
            twiddle = 1
            for lo in range(start, start + half):
                hi = lo + half
                upper = a[lo]
                lower = a[hi] * twiddle % MOD
                diff = upper - lower
                a[lo] = (upper + lower) % MOD
                a[hi] = diff + MOD if diff < 0 else diff
                twiddle = twiddle * step % MOD

    if invert:
        # Undo the factor of n accumulated by the forward/inverse pair.
        scale = pow(n, MOD - 2, MOD)
        a[:] = [value * scale % MOD for value in a]
    return a

def convolve(a, b):
    """Return the convolution of the sequences ``a`` and ``b`` modulo MOD.

    The result is a new list of length ``len(a) + len(b) - 1`` whose entry k
    is sum(a[i] * b[k - i]) reduced modulo MOD.  If either input is empty the
    result is an empty list.  Neither input is modified: padded copies are
    transformed instead of the caller's lists.
    """
    len_a = len(a)
    len_b = len(b)
    if len_a == 0 or len_b == 0:
        # Nothing to multiply; also avoids padding by a negative amount.
        return []

    out_len = len_a + len_b - 1
    # Smallest power of two that can hold the full product.
    size = 1
    while size < out_len:
        size <<= 1

    # Work on zero-padded copies so the arguments stay untouched.
    fa = ntt(list(a) + [0] * (size - len_a))
    fb = ntt(list(b) + [0] * (size - len_b))
    pointwise = [x * y % MOD for x, y in zip(fa, fb)]

    # The inverse transform already leaves every entry reduced modulo MOD.
    return ntt(pointwise, invert=True)[:out_len]

# Run the solver only when this file is executed as a script.
if __name__ == '__main__':
    main()
0