結果
問題 | No.931 Multiplicative Convolution |
ユーザー |
![]() |
提出日時 | 2025-04-16 16:18:55 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 324 ms / 2,000 ms |
コード長 | 2,936 bytes |
コンパイル時間 | 480 ms |
コンパイル使用メモリ | 82,048 KB |
実行使用メモリ | 130,108 KB |
最終ジャッジ日時 | 2025-04-16 16:20:20 |
合計ジャッジ時間 | 4,641 ms |
ジャッジサーバーID (参考情報) | judge1 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 14 |
ソースコード
import sys

mod = 998244353  # NTT-friendly prime: 998244353 == 119 * 2**23 + 1
root = 3  # primitive root modulo `mod`, used to build NTT twiddle factors


def factor(n):
    """Return the prime factorisation of ``n`` as a ``{prime: exponent}`` dict.

    Uses trial division; ``factor(1)`` returns ``{}``.

    Raises:
        ValueError: if ``n`` is not a positive integer (``n == 0`` would
            otherwise make the halving loop spin forever).
    """
    if n < 1:
        raise ValueError(f"factor() requires a positive integer, got {n}")
    factors = {}
    while n % 2 == 0:
        factors[2] = factors.get(2, 0) + 1
        n //= 2
    divisor = 3
    while divisor * divisor <= n:
        while n % divisor == 0:
            factors[divisor] = factors.get(divisor, 0) + 1
            n //= divisor
        divisor += 2
    if n > 1:
        # Whatever survives trial division up to sqrt(n) is itself prime.
        factors[n] = 1
    return factors


def find_primitive_root(p):
    """Return the smallest primitive root modulo the prime ``p``.

    ``g`` is a primitive root iff ``pow(g, (p - 1) // q, p) != 1`` for every
    prime ``q`` dividing ``p - 1``.  Returns ``1`` for ``p == 2`` and ``-1``
    if no candidate passes (which cannot happen when ``p`` is prime).

    Raises:
        ValueError: if ``p < 2``.
    """
    if p < 2:
        raise ValueError(f"find_primitive_root() requires p >= 2, got {p}")
    if p == 2:
        return 1
    phi = p - 1
    prime_divisors = factor(phi)
    for g in range(2, p):
        if all(pow(g, phi // q, p) != 1 for q in prime_divisors):
            return g
    return -1


def ntt(a, invert=False):
    """Apply an iterative radix-2 number-theoretic transform modulo ``mod``.

    ``a`` is transformed in place and also returned.  ``len(a)`` must be a
    power of two (and, for this modulus, at most ``2**23``).  With
    ``invert=True`` the inverse transform is computed, including the final
    multiplication by ``len(a)**-1``.

    Raises:
        ValueError: if ``len(a)`` is neither zero nor a power of two.
    """
    n = len(a)
    if n & (n - 1):
        raise ValueError(f"ntt() length must be a power of two, got {n}")

    # Bit-reversal permutation so the butterflies below can work in place.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    length = 2
    while length <= n:
        half = length // 2
        # Principal length-th root of unity (or its inverse for the inverse NTT).
        step = pow(root, (mod - 1) // length, mod)
        if invert:
            step = pow(step, mod - 2, mod)
        # The twiddle factors step**k, k in [0, half), are shared by every block.
        twiddles = [1] * half
        for k in range(1, half):
            twiddles[k] = twiddles[k - 1] * step % mod
        for start in range(0, n, length):
            for k in range(half):
                lo = start + k
                hi = lo + half
                u = a[lo]
                v = a[hi] * twiddles[k] % mod
                a[lo] = (u + v) % mod
                a[hi] = (u - v) % mod
        length <<= 1

    if invert:
        inv_n = pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * inv_n % mod
    return a


def convolve(a, b):
    """Return the linear convolution of ``a`` and ``b`` modulo ``mod``.

    The result has ``len(a) + len(b) - 1`` entries; if either input is
    empty the result is ``[]``.  The inputs are not modified.
    """
    if not a or not b:
        return []
    max_len = len(a) + len(b) - 1
    n = 1
    while n < max_len:
        n <<= 1
    fa = ntt(list(a) + [0] * (n - len(a)))
    fb = ntt(list(b) + [0] * (n - len(b)))
    product = [x * y % mod for x, y in zip(fa, fb)]
    return ntt(product, invert=True)[:max_len]


def multiplicative_convolution(p, A, B):
    """Return ``C`` where ``C[k-1]`` is the sum of ``A[i-1] * B[j-1]`` over
    all ``1 <= i, j < p`` with ``i * j % p == k``, reduced modulo ``mod``.

    ``p`` must be prime and ``A``/``B`` hold the values for indices
    ``1 .. p-1``.  Writing each nonzero residue as ``g**m`` for a primitive
    root ``g`` turns multiplication mod ``p`` into addition of exponents mod
    ``p - 1``.  The problem therefore becomes a cyclic convolution of length
    ``p - 1``, obtained here by folding an NTT-based linear convolution.
    """
    n = p - 1
    g = find_primitive_root(p)
    exp_table = [pow(g, m, p) for m in range(n)]  # exp_table[m] == g**m mod p
    log_table = [0] * p  # inverse map: log_table[g**m mod p] == m
    for m, residue in enumerate(exp_table):
        log_table[residue] = m

    # Re-index both sequences by discrete logarithm.
    a = [A[residue - 1] for residue in exp_table]
    b = [B[residue - 1] for residue in exp_table]
    linear = convolve(a, b)

    # Exponents live in Z/(p-1): fold indices >= n back onto index - n.
    cyclic = linear[:n]
    for m in range(n, len(linear)):
        cyclic[m - n] = (cyclic[m - n] + linear[m]) % mod

    return [cyclic[log_table[k]] for k in range(1, p)]


def main():
    """Read ``P``, ``A_1..A_{P-1}``, ``B_1..B_{P-1}`` from stdin and print
    ``C_1..C_{P-1}`` separated by spaces."""
    tokens = sys.stdin.read().split()
    p = int(tokens[0])
    A = [int(t) for t in tokens[1:p]]
    B = [int(t) for t in tokens[p:2 * p - 1]]
    print(' '.join(map(str, multiplicative_convolution(p, A, B))))


if __name__ == '__main__':
    main()