結果

| 項目 | 内容 |
|---|---|
| 問題 | No.2272 多項式乗算 mod 258280327 |
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-04-15 21:35:27 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | WA |
| 実行時間 | - |
| コード長 | 2,567 bytes |
| コンパイル時間 | 572 ms |
| コンパイル使用メモリ | 82,408 KB |
| 実行使用メモリ | 324,472 KB |
| 最終ジャッジ日時 | 2025-04-15 21:37:19 |
| 合計ジャッジ時間 | 10,060 ms |
| ジャッジサーバーID (参考情報) | judge2 / judge1 |
|---|---|

(要ログイン)

| ファイルパターン | 結果 |
|---|---|
| other | AC * 22 WA * 7 TLE * 2 -- * 2 |
ソースコード
import sys
import math
import cmath
def readints():
    """Read one line from standard input and return its whitespace-separated integers."""
    line = sys.stdin.readline()
    return [int(token) for token in line.split()]
# Modulus from the problem statement. 258280327 < 2**28, so every reduced
# coefficient fits in 28 bits.
mod_val = 258280327
def fft(a, invert):
    """Transform the complex list ``a`` in place with an iterative radix-2 FFT.

    The forward transform (``invert=False``) computes
    ``X[k] = sum(a[j] * exp(+2*pi*i*j*k/n))``; the inverse (``invert=True``)
    uses the conjugate root and then divides by ``n``, so
    ``fft(a, True)`` undoes ``fft(a, False)``.

    Args:
        a: list of complex numbers; its length must be a power of two.
            It is overwritten with the transform.
        invert: when true, compute the inverse transform.

    Returns:
        None. The result is written back into ``a``.
    """
    n = len(a)
    # Bit-reversal permutation so the butterflies below can run in place.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    sign = -1 if invert else 1
    length = 2
    while length <= n:
        half = length >> 1
        step = 2 * math.pi / length * sign
        # Compute each twiddle directly instead of by a running product
        # (w *= wlen). The running product accumulates rounding error
        # proportional to the stage length, which is enough to make large
        # integer convolutions round to the wrong value.
        roots = [complex(math.cos(step * k), math.sin(step * k)) for k in range(half)]
        for start in range(0, n, length):
            for k in range(half):
                u = a[start + k]
                v = a[start + k + half] * roots[k]
                a[start + k] = u + v
                a[start + k + half] = u - v
        length <<= 1
    if invert:
        for i in range(n):
            a[i] /= n
def multiply(a, b):
    """Return the linear convolution of ``a`` and ``b`` computed via FFT.

    Both inputs are zero-padded to the smallest power of two that is at
    least ``len(a) + len(b)``. The returned list has exactly that padded
    length, and every entry is the real part rounded to the nearest integer.
    Indices past ``len(a) + len(b) - 2`` carry no convolution terms and
    normally round to 0.
    """
    size = 1
    while size < len(a) + len(b):
        size *= 2
    lhs = [complex(x) for x in a] + [complex(0)] * (size - len(a))
    rhs = [complex(x) for x in b] + [complex(0)] * (size - len(b))
    fft(lhs, False)
    fft(rhs, False)
    # Pointwise product in the frequency domain is convolution in time.
    prod = [p * q for p, q in zip(lhs, rhs)]
    fft(prod, True)
    return [int(round(z.real)) for z in prod]
def _convolve_split_mod(f, g, split):
    """Return the coefficients of the product f * g, each reduced modulo ``mod_val``.

    Every coefficient c (0 <= c < mod_val) is written as
    ``hi * 2**split + lo``. Only four forward FFTs (one per half of each
    input) and three inverse FFTs (lo*lo, lo*hi + hi*lo, hi*hi) are needed.
    Each inverse result is rounded to an exact integer and recombined.

    Args:
        f, g: non-empty lists of ints already reduced modulo ``mod_val``.
        split: number of low bits placed in the ``lo`` half.

    Returns:
        A list of length ``len(f) + len(g) - 1``.
    """
    out_len = len(f) + len(g) - 1
    size = 1
    while size < out_len:
        size <<= 1
    mask = (1 << split) - 1

    def spectrum(coeffs):
        # Zero-pad to ``size`` (enough to avoid cyclic wrap-around) and transform.
        buf = [complex(c) for c in coeffs]
        buf.extend([complex(0)] * (size - len(buf)))
        fft(buf, False)
        return buf

    f_lo = spectrum([c & mask for c in f])
    f_hi = spectrum([c >> split for c in f])
    g_lo = spectrum([c & mask for c in g])
    g_hi = spectrum([c >> split for c in g])

    low = [p * q for p, q in zip(f_lo, g_lo)]
    # f_lo*g_hi + f_hi*g_lo combined before the inverse transform saves one FFT.
    cross = [p * s + r * q for p, r, q, s in zip(f_lo, f_hi, g_lo, g_hi)]
    high = [r * s for r, s in zip(f_hi, g_hi)]
    for buf in (low, cross, high):
        fft(buf, True)

    return [
        (round(low[k].real)
         + (round(cross[k].real) << split)
         + (round(high[k].real) << (2 * split))) % mod_val
        for k in range(out_len)
    ]


def main():
    """Read F and G from standard input and print H = F * G modulo ``mod_val``.

    Input lines: N, the coefficients of F, M, the coefficients of G.
    ``F[i]`` and ``G[i]`` are treated as the coefficients of x**i.
    Output: the index L of the highest nonzero coefficient of H (0 when H
    is identically zero), then H[0..L] on one line.
    """
    N = int(sys.stdin.readline())
    F = list(map(int, sys.stdin.readline().split()))
    M = int(sys.stdin.readline())
    G = list(map(int, sys.stdin.readline().split()))
    F_mod = [x % mod_val for x in F]
    G_mod = [x % mod_val for x in G]
    # mod_val < 2**28, so a balanced 14/14 split keeps both halves below
    # 2**14. Each floating-point convolution entry is then at most about
    # 2**28 * min(len(F), len(G)). The previous 17/11 split let lo*lo
    # terms reach 2**34 each, so their sums exceeded what a double-precision
    # FFT can round back to the exact integer.
    split = 14
    if not F_mod or not G_mod:
        # Degenerate input: treat the product as the zero polynomial.
        H = [0]
    else:
        H = _convolve_split_mod(F_mod, G_mod, split)
    L = len(H) - 1
    while L > 0 and H[L] == 0:
        L -= 1
    print(L)
    print(' '.join(map(str, H[:L + 1])))
# Run the solver only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
lam6er