結果
| 問題 | No.1195 数え上げを愛したい(文字列編) |
                    
| コンテスト | |
| ユーザー | lam6er |
                    
| 提出日時 | 2025-03-31 17:56:16 | 
| 言語 | PyPy3  (7.3.15)  | 
                    
| 結果 | AC |
                    
| 実行時間 | 1,746 ms / 3,000 ms | 
| コード長 | 3,220 bytes | 
| コンパイル時間 | 361 ms | 
| コンパイル使用メモリ | 82,336 KB | 
| 実行使用メモリ | 261,316 KB | 
| 最終ジャッジ日時 | 2025-03-31 17:57:59 | 
| 合計ジャッジ時間 | 27,157 ms | 
| ジャッジサーバーID (参考情報) | judge4 / judge3 |
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| other | AC * 26 | 
ソースコード
import sys
from collections import Counter, deque

# NTT-friendly prime: MOD - 1 = 119 * 2**23, and 3 is a primitive root.
MOD = 998244353


def _factorial_tables(n):
    """Return ``(fact, inv_fact)``, each of length ``n + 1``, modulo MOD.

    ``fact[i] = i! mod MOD`` and ``inv_fact[i] = (i!)^{-1} mod MOD``.
    Only the top inverse uses modular exponentiation (Fermat, MOD is
    prime); the rest follow from ``1/(i-1)! = i * 1/i!``.
    """
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], MOD - 2, MOD)
    for i in range(n, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD
    return fact, inv_fact


def _ntt(a, invert=False):
    """Apply an in-place iterative number-theoretic transform to ``a``.

    ``len(a)`` must be a power of two no larger than 2**23 so that a
    primitive ``len(a)``-th root of unity exists modulo MOD.  With
    ``invert=True`` the inverse transform (including the ``1/n`` scaling)
    is applied.  Every output entry is reduced into ``[0, MOD)``.
    Returns ``a`` itself.
    """
    n = len(a)
    # Bit-reversal permutation so the butterflies can run bottom-up.
    rev = [0] * n
    for i in range(1, n):
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (n >> 1))
        if i < rev[i]:
            a[i], a[rev[i]] = a[rev[i]], a[i]
    # Principal n-th root of unity (or its inverse for the inverse transform).
    exponent = (MOD - 1) // n
    if invert:
        exponent = MOD - 1 - exponent
    root = pow(3, exponent, MOD)
    roots = [1] * (n // 2)
    for i in range(1, n // 2):
        roots[i] = roots[i - 1] * root % MOD
    length = 2
    while length <= n:
        half = length >> 1
        step = n // length  # stride into `roots` for this butterfly size
        for start in range(0, n, length):
            for j in range(half):
                lo = start + j
                hi = lo + half
                u = a[lo]
                v = a[hi] * roots[j * step] % MOD
                a[lo] = (u + v) % MOD
                a[hi] = (u - v) % MOD
        length <<= 1
    if invert:
        inv_n = pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * inv_n % MOD
    return a


def _multiply(a, b):
    """Return the product of two non-empty coefficient lists modulo MOD.

    The inputs are not modified (padded copies are transformed).
    """
    result_len = len(a) + len(b) - 1
    size = 1
    while size < result_len:
        size <<= 1
    fa = _ntt(a + [0] * (size - len(a)))
    fb = _ntt(b + [0] * (size - len(b)))
    product = [x * y % MOD for x, y in zip(fa, fb)]
    # The inverse NTT already reduces every entry, so no extra `% MOD` pass.
    return _ntt(product, invert=True)[:result_len]


def count_strings(s):
    """Count distinct non-empty strings buildable from the characters of ``s``.

    Each character ``c`` may be used at most ``s.count(c)`` times, and
    order matters.  The result is returned modulo MOD; an empty ``s``
    yields 0.

    The method multiplies the truncated exponential generating functions
    ``sum_{k=0}^{m_c} x^k / k!`` over the distinct characters.  ``k!``
    times the coefficient of ``x^k`` is then the number of length-``k``
    strings (a sum of multinomial coefficients).  Any hashable characters
    are accepted, not only ``'a'..'z'``.
    """
    counts = list(Counter(s).values())
    if not counts:
        return 0
    total = sum(counts)
    fact, inv_fact = _factorial_tables(total)
    # One fresh list per character: [1/0!, 1/1!, ..., 1/m!].
    queue = deque(inv_fact[:m + 1] for m in counts)
    # FIFO pairwise merging keeps partial products roughly balanced in size.
    while len(queue) > 1:
        left = queue.popleft()
        right = queue.popleft()
        queue.append(_multiply(left, right))
    poly = queue[0]  # length total + 1, so it lines up with `fact`
    strings = sum(c * f % MOD for c, f in zip(poly, fact)) % MOD
    # Drop the k = 0 term: the empty string is not counted.
    return (strings - 1) % MOD


def main():
    """Read the string S from standard input and print the answer mod MOD."""
    s = sys.stdin.read().strip()
    print(count_strings(s))


if __name__ == "__main__":
    main()
            
            
            
        
            
lam6er