結果

問題 No.1504 ヌメロニム
ユーザー lam6er
提出日時 2025-03-31 18:01:29
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 4,004 bytes
コンパイル時間 188 ms
コンパイル使用メモリ 82,060 KB
実行使用メモリ 279,448 KB
最終ジャッジ日時 2025-03-31 18:02:19
合計ジャッジ時間 5,648 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 23 WA * 1 TLE * 1 -- * 36
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
# Prime modulus for all modular arithmetic below; 998244353 = 119 * 2**23 + 1,
# so MOD - 1 is divisible by every power-of-two transform length up to 2**23.
MOD = 998244353
# Base from which the NTT derives its roots of unity:
# pow(primitive_root, (MOD - 1) // length, MOD) is used as the twiddle step.
primitive_root = 3

def _ntt(a, invert):
    """Apply an in-place iterative radix-2 number-theoretic transform to ``a``.

    Args:
        a: list of residues modulo ``MOD``; its length must be a power of two.
        invert: when true, perform the inverse transform (including the
            final division by ``len(a)``).

    Returns:
        The same list ``a``, transformed in place.
    """
    n = len(a)

    # Bit-reversal permutation so the butterflies can run bottom-up.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    length = 2
    while length <= n:
        half = length >> 1
        step = pow(primitive_root, (MOD - 1) // length, MOD)
        if invert:
            step = pow(step, MOD - 2, MOD)
        for start in range(0, n, length):
            w = 1
            for k in range(start, start + half):
                u = a[k]
                v = a[k + half] * w % MOD
                a[k] = (u + v) % MOD
                a[k + half] = (u - v) % MOD
                w = w * step % MOD
        length <<= 1

    if invert:
        inv_n = pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * inv_n % MOD
    return a


def _convolve(f, g):
    """Return the linear convolution of ``f`` and ``g`` modulo ``MOD``.

    The inputs are not modified. The result has ``len(f) + len(g) - 1``
    entries, or is empty when either input is empty.
    """
    if not f or not g:
        return []
    out_len = len(f) + len(g) - 1
    size = 1
    while size < out_len:
        size <<= 1
    fa = f + [0] * (size - len(f))
    ga = g + [0] * (size - len(g))
    _ntt(fa, False)
    _ntt(ga, False)
    product = [x * y % MOD for x, y in zip(fa, ga)]
    _ntt(product, True)
    return product[:out_len]


def _gap_counts(S):
    """Count ('i', 'n') pairs of ``S`` grouped by the gap between them.

    Returns a list ``cnt`` where ``cnt[m]`` is the number (modulo ``MOD``) of
    index pairs ``p < q`` with ``S[p] == 'i'``, ``S[q] == 'n'`` and exactly
    ``m = q - p - 1`` characters in between. Trailing zero entries are
    stripped, so an empty list means there is no such pair.
    """
    length = len(S)
    i_reversed = [1 if ch == 'i' else 0 for ch in reversed(S)]
    n_mask = [1 if ch == 'n' else 0 for ch in S]
    if not any(i_reversed) or not any(n_mask):
        return []

    # Cross-correlation via convolution: with the 'i' indicator reversed,
    # index (length - 1 + d) of the product counts pairs with q - p == d.
    # Only d >= 1 matters, i.e. gap m = d - 1 lives at index length + m.
    cnt = _convolve(i_reversed, n_mask)[length:]
    while cnt and cnt[-1] == 0:
        cnt.pop()
    return cnt


def _solve(N, S):
    """Return the XOR over k in [0, N-2] of X_k for the string ``S``.

    X_k is the sum of C(q - p - 1, k) over all pairs p < q with
    ``S[p] == 'i'`` and ``S[q] == 'n'``, reduced modulo ``MOD``.
    """
    cnt = _gap_counts(S)
    if not cnt:
        return 0
    max_m = len(cnt) - 1

    fact = [1] * (max_m + 1)
    for i in range(1, max_m + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (max_m + 1)
    inv_fact[max_m] = pow(fact[max_m], MOD - 2, MOD)
    for i in range(max_m - 1, -1, -1):
        inv_fact[i] = inv_fact[i + 1] * (i + 1) % MOD

    # X_k = (1 / k!) * sum_m cnt[m] * m! / (m - k)!.  Reversing the weighted
    # counts turns that sum into a convolution read at index max_m - k.
    weighted_reversed = [cnt[m] * fact[m] % MOD for m in range(max_m, -1, -1)]
    conv = _convolve(weighted_reversed, inv_fact)

    result = 0
    # X_k is zero for k > max_m, so those terms do not affect the XOR.
    for k in range(min(N - 1, max_m + 1)):
        result ^= conv[max_m - k] * inv_fact[k] % MOD
    # The XOR of residues may legitimately exceed MOD; it must not be
    # reduced again.
    return result


def main():
    """Read N and S from standard input and print the XOR of X_0..X_{N-2}.

    Input is two lines: the integer N, then the string S. The printed value
    is the bitwise XOR of the residues X_k modulo ``MOD``; the XOR itself is
    printed as-is, without a further modular reduction.
    """
    N = int(sys.stdin.readline())
    S = sys.stdin.readline().strip()
    print(_solve(N, S))

# Script entry point: run the solver only when this file is executed directly.
if __name__ == "__main__":
    main()
0