結果

問題 No.2272 多項式乗算 mod 258280327
ユーザー lam6er
提出日時 2025-04-16 15:32:09
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,567 bytes
コンパイル時間 182 ms
コンパイル使用メモリ 82,000 KB
実行使用メモリ 288,524 KB
最終ジャッジ日時 2025-04-16 15:37:27
合計ジャッジ時間 9,348 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 22 WA * 7 TLE * 2 -- * 2
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import math
import cmath

def readints():
    return list(map(int, sys.stdin.readline().split()))

mod_val = 258280327

def fft(a, invert):
    n = len(a)
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    length = 2
    while length <= n:
        angle = 2 * cmath.pi / length * (-1 if invert else 1)
        wlen = complex(math.cos(angle), math.sin(angle))
        for i in range(0, n, length):
            w = complex(1)
            for j in range(length // 2):
                u = a[i + j]
                v = a[i + j + length // 2] * w
                a[i + j] = u + v
                a[i + j + length // 2] = u - v
                w *= wlen
        length <<= 1
    if invert:
        for i in range(n):
            a[i] /= n

def multiply(a, b):
    n = 1
    max_len = len(a) + len(b)
    while n < max_len:
        n <<= 1
    fa = [complex(0)] * n
    fb = [complex(0)] * n
    for i in range(len(a)):
        fa[i] = complex(a[i])
    for i in range(len(b)):
        fb[i] = complex(b[i])
    fft(fa, False)
    fft(fb, False)
    for i in range(n):
        fa[i] *= fb[i]
    fft(fa, True)
    res = [0] * n
    for i in range(n):
        res[i] = int(round(fa[i].real))
    return res

def main():
    N = int(sys.stdin.readline())
    F = list(map(int, sys.stdin.readline().split()))
    M = int(sys.stdin.readline())
    G = list(map(int, sys.stdin.readline().split()))
    
    F_mod = [x % mod_val for x in F]
    G_mod = [x % mod_val for x in G]
    
    split = 17
    mask = (1 << split) - 1
    
    F_lo = [x & mask for x in F_mod]
    F_hi = [x >> split for x in F_mod]
    G_lo = [x & mask for x in G_mod]
    G_hi = [x >> split for x in G_mod]
    
    conv_lo_lo = multiply(F_lo, G_lo)
    conv_lo_hi = multiply(F_lo, G_hi)
    conv_hi_lo = multiply(F_hi, G_lo)
    conv_hi_hi = multiply(F_hi, G_hi)
    
    max_k = len(F_mod) + len(G_mod) - 2
    H = [0] * (max_k + 1)
    for k in range(len(H)):
        ll = conv_lo_lo[k] if k < len(conv_lo_lo) else 0
        lh = conv_lo_hi[k] if k < len(conv_lo_hi) else 0
        hl = conv_hi_lo[k] if k < len(conv_hi_lo) else 0
        hh = conv_hi_hi[k] if k < len(conv_hi_hi) else 0
        total = ll + ((lh + hl) << split) + (hh << (2 * split))
        H[k] = total % mod_val
    
    L = max_k
    while L > 0 and H[L] == 0:
        L -= 1
    
    print(L)
    print(' '.join(map(str, H[:L+1])))

if __name__ == "__main__":
    main()
0