結果
問題 |
No.1145 Sums of Powers
|
ユーザー |
![]() |
提出日時 | 2025-04-16 16:31:35 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 3,160 bytes |
コンパイル時間 | 429 ms |
コンパイル使用メモリ | 81,908 KB |
実行使用メモリ | 160,760 KB |
最終ジャッジ日時 | 2025-04-16 16:33:25 |
合計ジャッジ時間 | 9,052 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 3 TLE * 3 |
ソースコード
# yukicoder No.1145 "Sums of Powers": for K = 1..M print sum_i A_i^K mod MOD.
#
# Identity used: with P(x) = prod_i (1 - A_i x),
#     sum_{K>=1} (sum_i A_i^K) x^K = -x P'(x) / P(x),
# so the K-th answer is the coefficient of x^(K-1) in -P'(x) / P(x).
import sys

MOD = 998244353
ROOT = 3  # primitive root modulo MOD (MOD - 1 = 119 * 2**23)

# Below these sizes the O(n*m) schoolbook algorithms beat paying for three
# NTTs; the previous version ran NTTs all the way down to 2-term products,
# which dominated the running time (TLE).
NAIVE_CONV_CUTOFF = 64
PRODUCT_LEAF_SIZE = 32


def ntt(a, invert=False):
    """Apply the number-theoretic transform to ``a`` in place, modulo MOD.

    ``len(a)`` must be a power of two (at most 2**23 for this modulus).
    Entries may be arbitrary ints; they are reduced modulo MOD first, which
    lets the butterflies use a conditional subtraction instead of ``%``.
    With ``invert=True`` the inverse transform, including the 1/n scaling,
    is applied. Returns ``a`` for convenience.
    """
    n = len(a)
    for i in range(n):
        a[i] %= MOD
    # Bit-reversal permutation so the iterative butterflies can run in place.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    length = 2
    while length <= n:
        half = length >> 1
        w_step = pow(ROOT, (MOD - 1) // length, MOD)
        if invert:
            w_step = pow(w_step, MOD - 2, MOD)
        # Twiddle factors for this stage, computed once rather than per block.
        twiddles = [1] * half
        for k in range(1, half):
            twiddles[k] = twiddles[k - 1] * w_step % MOD
        for start in range(0, n, length):
            for k in range(half):
                p = start + k
                q = p + half
                u = a[p]
                v = a[q] * twiddles[k] % MOD
                s = u + v
                a[p] = s - MOD if s >= MOD else s
                d = u - v
                a[q] = d + MOD if d < 0 else d
        length <<= 1
    if invert:
        inv_n = pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * inv_n % MOD
    return a


def _naive_convolution(a, b):
    """Schoolbook product of two coefficient lists already reduced mod MOD."""
    if len(a) > len(b):
        a, b = b, a  # keep the shorter list in the outer loop
    res = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        if not x:
            continue
        for j, y in enumerate(b):
            # Reduce every step so values stay below 2**63 (fast ints on PyPy).
            res[i + j] = (res[i + j] + x * y) % MOD
    return res


def convolution(a, b):
    """Return the product of polynomials ``a`` and ``b`` modulo MOD.

    Coefficient lists are lowest degree first. The inputs are NOT modified
    (the previous version padded the caller's lists in place). An empty
    operand represents the zero polynomial and yields ``[]``.
    """
    len_a, len_b = len(a), len(b)
    if not len_a or not len_b:
        return []
    if min(len_a, len_b) <= NAIVE_CONV_CUTOFF:
        return _naive_convolution([x % MOD for x in a], [y % MOD for y in b])
    size = len_a + len_b - 1
    n = 1 << (size - 1).bit_length()  # smallest power of two >= size
    fa = ntt(list(a) + [0] * (n - len_a))
    fb = ntt(list(b) + [0] * (n - len_b))
    c = ntt([x * y % MOD for x, y in zip(fa, fb)], invert=True)
    return c[:size]


def multiply(a, b):
    """Polynomial product modulo MOD (alias of ``convolution``)."""
    return convolution(a, b)


def product_helper(a_list, l, r):
    """Return the coefficients of prod_{i=l..r} (1 - a_list[i] * x) mod MOD.

    Small ranges are expanded directly (multiplying by one linear factor at a
    time); larger ranges are split in half and combined with ``multiply``.
    An empty range (l > r) yields ``[1]``.
    """
    if r - l < PRODUCT_LEAF_SIZE:
        poly = [1]
        for i in range(l, r + 1):
            c = a_list[i] % MOD
            poly.append(0)
            # Multiply by (1 - c*x) in place; walk top-down so poly[k - 1]
            # still holds the old coefficient when poly[k] is updated.
            for k in range(len(poly) - 1, 0, -1):
                poly[k] = (poly[k] - c * poly[k - 1]) % MOD
        return poly
    mid = (l + r) // 2
    return multiply(product_helper(a_list, l, mid),
                    product_helper(a_list, mid + 1, r))


def compute_P(A):
    """Return P(x) = prod_{a in A} (1 - a x) modulo MOD; ``[1]`` for empty A."""
    if not A:
        return [1]
    return product_helper(A, 0, len(A) - 1)


def compute_derivative(P):
    """Return the formal derivative of polynomial ``P`` modulo MOD."""
    return [i * P[i] % MOD for i in range(1, len(P))]


def inverse(P, m):
    """Return the first ``m`` coefficients of 1 / P(x) modulo MOD.

    Newton iteration g <- g * (2 - P * g) doubles the number of correct
    coefficients each round. Returns ``[]`` when ``P`` is empty or ``m <= 0``.

    Raises:
        ValueError: if P[0] is divisible by MOD (no power-series inverse).
    """
    if not P or m <= 0:
        return []
    c0 = P[0] % MOD
    if c0 == 0:
        raise ValueError("constant term must be invertible modulo MOD")
    g = [pow(c0, MOD - 2, MOD)]
    cur = 1
    while cur < m:
        nxt = min(cur * 2, m)
        fg = multiply(P[:nxt], g)[:nxt]
        # 2 - P*g (mod x^nxt); pad in case P has fewer than nxt terms.
        two_minus = [(-x) % MOD for x in fg] + [0] * (nxt - len(fg))
        two_minus[0] = (two_minus[0] + 2) % MOD
        g = multiply(g, two_minus)[:nxt]
        cur = nxt
    return g


def solve(n, m, a):
    """Return [sum_{i<n} a[i]**k mod MOD for k = 1..m].

    The k-th value is the coefficient of x^(k-1) in -P'(x) / P(x) with
    P(x) = prod_i (1 - a_i x).
    """
    if m <= 0:
        return []
    if n <= 0:
        return [0] * m
    p = compute_P(a[:n])
    dp = compute_derivative(p)[:m]
    r = multiply(dp, inverse(p, m))[:m]
    r += [0] * (m - len(r))
    return [(-x) % MOD for x in r]


def main():
    """Read "N M" then A_1..A_N from stdin and print the M power sums."""
    data = sys.stdin.buffer.read().split()
    n, m = int(data[0]), int(data[1])
    a = [int(tok) for tok in data[2:2 + n]]
    print(" ".join(map(str, solve(n, m, a))))


if __name__ == "__main__":
    main()