結果

問題 No.1839 Concatenation Matrix
ユーザー lam6er
提出日時 2025-03-31 18:01:09
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,089 bytes
コンパイル時間 185 ms
コンパイル使用メモリ 82,668 KB
実行使用メモリ 190,096 KB
最終ジャッジ日時 2025-03-31 18:02:15
合計ジャッジ時間 5,045 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 4 WA * 12
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353
PHI = MOD - 1

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    n = int(data[0])
    a = list(map(int, data[1:n+1]))
    
    current = [x % MOD for x in a]
    
    k = min(23, n-1)
    
    for step in range(1, k+1):
        expon = 1 << (step-1)
        multiplier = pow(10, expon, MOD)
        next_current = []
        for j in range(n):
            nj = (j + 1) % n
            val = (current[j] * multiplier + current[nj]) % MOD
            next_current.append(val)
        current = next_current
    
    remaining_steps = (n-1) - k
    if remaining_steps > 0:
        m = remaining_steps
        max_fact = m
        fact = [1] * (max_fact + 1)
        for i in range(1, max_fact + 1):
            fact[i] = fact[i-1] * i % MOD
        inv_fact = [1] * (max_fact + 1)
        inv_fact[max_fact] = pow(fact[max_fact], MOD-2, MOD)
        for i in range(max_fact -1, -1, -1):
            inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
        
        def comb(m, k):
            if k <0 or k > m:
                return 0
            return fact[m] * inv_fact[k] % MOD * inv_fact[m - k] % MOD
        
        G = [0] * n
        for k in range(0, m + 1):
            pos = k % n
            G[pos] = (G[pos] + comb(m, k)) % MOD
        
        def ntt(a, invert=False):
            n = len(a)
            rev = list(range(n))
            for i in range(1, n):
                rev[i] = rev[i >> 1] >> 1
                if i & 1:
                    rev[i] |= n >> 1
                if i < rev[i]:
                    a[i], a[rev[i]] = a[rev[i]], a[i]
            log_n = (n).bit_length() - 1
            root = pow(3, (MOD - 1) // n, MOD) if not invert else pow(3, MOD-1 - (MOD -1) // n, MOD)
            roots = [1] * (n // 2)
            for i in range(1, len(roots)):
                roots[i] = roots[i-1] * root % MOD
            for m in range(1, log_n + 1):
                m_h = 1 << (m-1)
                m_len = m_h << 1
                for i in range(0, n, m_len):
                    for j in range(m_h):
                        u = a[i + j]
                        v = a[i + j + m_h] * roots[j * (n >> m)] % MOD
                        a[i + j] = (u + v) % MOD
                        a[i + j + m_h] = (u - v) % MOD
            if invert:
                inv_n = pow(n, MOD-2, MOD)
                for i in range(n):
                    a[i] = a[i] * inv_n % MOD
            return a
        
        size = 1
        while size < n:
            size <<= 1
        size <<= 1
        
        a_pad = current[:] + [0] * (size - n)
        g_pad = G[:] + [0] * (size - n)
        
        a_pad_ntt = ntt(a_pad.copy())
        g_pad_ntt = ntt(g_pad.copy())
        
        conv_ntt = [ (a * b) % MOD for a, b in zip(a_pad_ntt, g_pad_ntt)]
        
        conv = ntt(conv_ntt, invert=True)
        
        result = [0] * n
        for i in range(n):
            result[i] = (conv[i] + conv[i + n]) % MOD
        
        current = result
    
    for x in current:
        print(x)

if __name__ == '__main__':
    main()
0