結果
| 項目 | 内容 |
|---|---|
| 問題 | No.931 Multiplicative Convolution |
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-04-15 23:18:21 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 321 ms / 2,000 ms |
| コード長 | 2,936 bytes |
| コンパイル時間 | 162 ms |
| コンパイル使用メモリ | 81,652 KB |
| 実行使用メモリ | 129,760 KB |
| 最終ジャッジ日時 | 2025-04-15 23:19:43 |
| 合計ジャッジ時間 | 4,568 ms |
| ジャッジサーバーID (参考情報) | judge4 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 14 |
ソースコード
import sys
# NTT modulus: the prime 119 * 2**23 + 1 (== 998244353). Power-of-two
# transform lengths must divide mod - 1, i.e. be at most 2**23.
mod = 119 * (1 << 23) + 1
# Generator used by ntt() to derive the roots of unity modulo `mod`.
root = 3
def factor(n):
    """Return the prime factorization of ``n`` as a ``{prime: exponent}`` dict.

    Trial division: first strip factors of 2, then try odd divisors while
    ``d * d <= n``. Keys appear in increasing order. ``factor(1) == {}``.

    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        # 0 % 2 == 0 forever, so trial division would never terminate;
        # negative inputs have no meaningful factorization here either.
        raise ValueError(f"factor() requires a positive integer, got {n}")
    factors = {}
    while n % 2 == 0:
        factors[2] = factors.get(2, 0) + 1
        n //= 2
    d = 3
    while d * d <= n:
        while n % d == 0:
            factors[d] = factors.get(d, 0) + 1
            n //= d
        d += 2
    if n > 1:
        # Every factor <= sqrt(n) has been divided out, so the rest is prime.
        factors[n] = 1
    return factors
def find_primitive_root(p):
    """Return the smallest primitive root candidate modulo ``p``.

    ``g`` is accepted when ``g ** ((p - 1) // q) % p != 1`` for every prime
    ``q`` dividing ``p - 1`` (the primes come from ``factor``). Candidates
    are tried in increasing order from 2. Special case: ``p == 2`` gives 1.
    Returns -1 if no candidate in ``[2, p)`` passes.
    """
    if p == 2:
        return 1
    order = p - 1
    prime_divisors = factor(order)
    for candidate in range(2, p):
        if all(pow(candidate, order // q, p) != 1 for q in prime_divisors):
            return candidate
    return -1
def ntt(a, invert=False, *, modulus=None, prim_root=None):
    """Apply an in-place iterative radix-2 number-theoretic transform to ``a``.

    Forward: ``a[k] <- sum_j a[j] * w**(j*k) (mod modulus)``, where
    ``w = prim_root ** ((modulus - 1) // len(a))``. With ``invert=True`` the
    inverse transform is applied instead, including the ``1 / len(a)``
    scaling, so ``ntt(ntt(x), invert=True) == x`` for reduced ``x``.

    Args:
        a: list of ints; its length must be 0 or a power of two dividing
            ``modulus - 1``.
        invert: compute the inverse transform when True.
        modulus: prime modulus; defaults to the module-level ``mod``.
        prim_root: primitive root of ``modulus``; defaults to the
            module-level ``root``.

    Returns:
        ``a`` itself, overwritten with the transform.

    Raises:
        ValueError: if ``len(a)`` is not a power of two dividing
            ``modulus - 1``.
    """
    if modulus is None:
        modulus = mod
    if prim_root is None:
        prim_root = root
    n = len(a)
    if n & (n - 1) or (n and (modulus - 1) % n):
        # A non-power-of-two length would make the bit-reversal loop spin
        # forever, and a length not dividing modulus - 1 has no root of unity.
        raise ValueError(f"ntt length {n} must be a power of two dividing modulus - 1")
    # Bit-reversal permutation so the butterflies below can run in place.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j |= bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    length = 2
    while length <= n:
        half = length >> 1
        step = pow(prim_root, (modulus - 1) // length, modulus)
        if invert:
            step = pow(step, modulus - 2, modulus)
        # Twiddle factors are identical for every block at this level.
        twiddles = [1] * half
        for k in range(1, half):
            twiddles[k] = twiddles[k - 1] * step % modulus
        for start in range(0, n, length):
            for k in range(half):
                lo = start + k
                hi = lo + half
                u = a[lo]
                v = a[hi] * twiddles[k] % modulus
                a[lo] = (u + v) % modulus
                a[hi] = (u - v) % modulus
        length <<= 1
    if invert:
        inv_n = pow(n, modulus - 2, modulus)
        for i in range(n):
            a[i] = a[i] * inv_n % modulus
    return a
def convolve(a, b):
    """Return the linear convolution of lists ``a`` and ``b`` modulo ``mod``.

    ``c[k] = sum(a[i] * b[k - i]) % mod`` for ``0 <= k < len(a) + len(b) - 1``,
    computed via ``ntt`` on zero-padded copies (inputs are not modified).
    If either input is empty the result is ``[]``.
    """
    if not a or not b:
        # Otherwise the padding below goes negative and the arrays no longer
        # have the power-of-two length ntt() requires.
        return []
    out_len = len(a) + len(b) - 1
    size = 1
    while size < out_len:
        size <<= 1
    # `+` creates new lists, so the caller's sequences stay untouched.
    fa = ntt(a + [0] * (size - len(a)))
    fb = ntt(b + [0] * (size - len(b)))
    prod = [x * y % mod for x, y in zip(fa, fb)]
    return ntt(prod, invert=True)[:out_len]
def main():
    """Solve "No.931 Multiplicative Convolution" from stdin to stdout.

    Input: P, then A_1..A_{P-1}, then B_1..B_{P-1}. Prints C_1..C_{P-1},
    where C_k is the sum of A_i * B_j over all i * j == k (mod P), taken
    modulo ``mod``.

    Writing i = g**s for a primitive root g turns multiplication mod P into
    addition of exponents mod P - 1. The problem then becomes a cyclic
    convolution of length P - 1, obtained by folding a linear convolution.
    """
    tokens = sys.stdin.read().split()
    P = int(tokens[0])
    N = P - 1
    A = list(map(int, tokens[1:1 + N]))
    B = list(map(int, tokens[1 + N:1 + 2 * N]))
    if P == 2:
        # Only residue 1 exists, so C_1 = A_1 * B_1.
        print(A[0] * B[0] % mod)
        return
    g = find_primitive_root(P)
    # exp_table[s] = g**s mod P; log_table is its inverse on 1..P-1.
    exp_table = [pow(g, s, P) for s in range(N)]
    log_table = [0] * P
    for s, value in enumerate(exp_table):
        log_table[value] = s
    # Re-index by exponent: a[s] = A_{g^s}, b[s] = B_{g^s} (A, B are 1-indexed).
    a = [A[v - 1] for v in exp_table]
    b = [B[v - 1] for v in exp_table]
    linear = convolve(a, b)
    # Exponents add mod N: fold the tail (indices >= N) back onto the front.
    cyclic = linear[:N]
    for s in range(N, len(linear)):
        cyclic[s - N] = (cyclic[s - N] + linear[s]) % mod
    # Entries are already reduced mod `mod` by convolve/the fold above.
    C = [cyclic[log_table[k]] for k in range(1, P)]
    print(' '.join(map(str, C)))
if __name__ == "__main__":
    # Solve only when run as a script; importing the module does nothing.
    main()
lam6er