結果

問題 No.1839 Concatenation Matrix
ユーザー gew1fw
提出日時 2025-06-12 16:21:46
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,533 ms / 3,500 ms
コード長 3,589 bytes
コンパイル時間 226 ms
コンパイル使用メモリ 82,072 KB
実行使用メモリ 163,528 KB
最終ジャッジ日時 2025-06-12 16:22:09
合計ジャッジ時間 13,548 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
MOD = 998244353

def main():
    """Read N and a_1..a_N from stdin and print N residues modulo MOD.

    For N == 1 the single value is printed reduced modulo MOD.  Otherwise
    the output is the cyclic convolution (modulo x^N - 1) of the input
    sequence with

        P(x) = x * prod_{i=2..N} (1 + p_i * x),   p_i = 10^(2^(i-2)) mod MOD,

    one coefficient per line, index 0 first.

    Relies on the module-level ``MOD`` constant and the in-place ``ntt``
    transform defined in this file.
    """
    tokens = sys.stdin.read().split()
    N = int(tokens[0])
    a = list(map(int, tokens[1:N + 1]))

    if N == 1:
        print(a[0] % MOD)
        return

    # p_2 = 10 and p_{i+1} = p_i^2, so one modular squaring per step yields
    # 10^(2^(i-2)) mod MOD without a separate exponent table or pow() calls.
    factors = []
    p = 10 % MOD
    for _ in range(N - 1):
        factors.append(p)
        p = p * p % MOD

    def multiply(f, g):
        """Return the linear convolution of coefficient lists f and g."""
        out_len = len(f) + len(g) - 1
        size = 1
        while size < out_len:
            size <<= 1
        # Fresh zero-padded lists are passed because ntt works in place.
        f_hat = ntt(f + [0] * (size - len(f)))
        g_hat = ntt(g + [0] * (size - len(g)))
        h = ntt([x * y % MOD for x, y in zip(f_hat, g_hat)], invert=True)
        del h[out_len:]
        return h

    def product(lo, hi):
        """Return prod_{k=lo..hi} (1 + factors[k] * x), lowest degree first."""
        if lo == hi:
            return [1, factors[lo]]
        mid = (lo + hi) // 2
        # Balanced split keeps the recursion depth at about log2(N).
        return multiply(product(lo, mid), product(mid + 1, hi))

    # Q has exactly N coefficients (degree N - 1).
    Q = product(0, len(factors) - 1)

    # x * Q(x) mod (x^N - 1): shifting up by one wraps the top coefficient
    # around to the constant term.
    P = Q[-1:] + Q[:-1]

    # Cyclic convolution = linear convolution folded back modulo x^N - 1.
    result = [0] * N
    for i, c in enumerate(multiply(P, a)):
        idx = i % N
        result[idx] = (result[idx] + c) % MOD

    # Single write instead of one print call per line.
    print("\n".join(map(str, result)))

def ntt(a, invert=False, mod=998244353, primitive_root=3):
    """Number-theoretic transform of ``a``, performed in place.

    Args:
        a: Mutable list of integers.  Its length must be a power of two
            that divides ``mod - 1``; an empty list is returned unchanged.
        invert: When true, apply the inverse transform (inverse root of
            unity and a final scaling by ``1 / len(a)``).
        mod: Prime modulus.  The default equals the module-level ``MOD``
            (998244353) used by the rest of this file.
        primitive_root: A primitive root modulo ``mod`` (3 for the default).

    Returns:
        The same list object ``a``, now holding the transformed values
        reduced to the range ``[0, mod)``.

    Raises:
        ValueError: If ``len(a)`` is not a power of two or does not divide
            ``mod - 1`` (no root of unity of that order exists).
    """
    n = len(a)
    if n == 0:
        # Nothing to transform; avoids dividing by zero below.
        return a
    if n & (n - 1) or (mod - 1) % n:
        raise ValueError(
            "ntt length must be a power of two dividing mod - 1, got %d" % n
        )
    log_n = n.bit_length() - 1
    half = n >> 1

    # Bit-reversal permutation.  rev[i] is derived from rev[i >> 1], and
    # each pair is swapped once (only when i < rev[i]).
    rev = [0] * n
    for i in range(1, n):
        r = rev[i >> 1] >> 1
        if i & 1:
            r |= half
        rev[i] = r
        if i < r:
            a[i], a[r] = a[r], a[i]

    # Principal n-th root of unity, or its inverse for the inverse transform.
    step = (mod - 1) // n
    root = pow(primitive_root, mod - 1 - step if invert else step, mod)

    # The butterflies only read powers root^k with k < n / 2.
    twiddles = [1] * max(half, 1)
    for k in range(1, half):
        twiddles[k] = twiddles[k - 1] * root % mod

    for m in range(1, log_n + 1):
        m_h = 1 << (m - 1)
        stride = n >> m  # root^(j * stride) is the j-th power of the 2^m-th root
        for start in range(0, n, m_h << 1):
            for j in range(m_h):
                lo = start + j
                hi = lo + m_h
                u = a[lo]
                v = a[hi] * twiddles[j * stride] % mod
                a[lo] = (u + v) % mod
                a[hi] = (u - v) % mod

    if invert:
        # Fermat inverse of n; requires mod to be prime.
        inv_n = pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * inv_n % mod
    return a

# Run the solver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
0