結果
問題 | No.1145 Sums of Powers |
ユーザー |
![]() |
提出日時 | 2025-04-16 16:30:48 |
言語 | PyPy3 (7.3.15) |
結果 | TLE |
実行時間 | - |
コード長 | 3,335 bytes |
コンパイル時間 | 529 ms |
コンパイル使用メモリ | 81,816 KB |
実行使用メモリ | 166,572 KB |
最終ジャッジ日時 | 2025-04-16 16:32:47 |
合計ジャッジ時間 | 4,053 ms |
ジャッジサーバーID (参考情報) | judge4 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 3 TLE * 1 -- * 2 |
ソースコード
"""Power sums of a multiset modulo 998244353.

Reads ``n m`` and ``n`` integers from standard input and prints
``sum(a_i ** k) mod 998244353`` for ``k = 1 .. m`` on one line.

The sums are the coefficients of ``-P'(x) / P(x)`` where
``P(x) = prod(1 - a_i * x)``; the product, the power-series inverse and the
final multiplication are all carried out with a number-theoretic transform.
"""
import heapq
import sys

mod = 998244353
root = 3  # Primitive root for mod 998244353

# Operands at most this long are multiplied with the schoolbook method;
# for such sizes three transforms cost more than the direct double loop.
_NAIVE_LIMIT = 16


def ntt(a, invert=False):
    """Transform the list ``a`` in place with an iterative NTT modulo ``mod``.

    Args:
        a: List of integers. Its length is expected to be a power of two
            (lengths 0 and 1 are returned without any butterfly pass).
        invert: When true, use the inverse roots and scale the result by the
            modular inverse of ``len(a)``.

    Returns:
        The same list object ``a`` holding the transformed values.
    """
    n = len(a)

    # Bit-reversal permutation: ``j`` tracks the bit-reversed value of ``i``.
    j = 0
    for i in range(1, n - 1):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    half = 1
    while half < n:
        m = half << 1
        w_m = pow(root, (mod - 1) // m, mod)
        if invert:
            w_m = pow(w_m, mod - 2, mod)

        # The twiddle factors depend only on the stage, not on the block, so
        # build them once instead of once per block.
        twiddles = [1] * half
        for t in range(1, half):
            twiddles[t] = twiddles[t - 1] * w_m % mod

        for start in range(0, n, m):
            for offset in range(half):
                lo = start + offset
                hi = lo + half
                u = a[lo]
                v = a[hi] * twiddles[offset] % mod
                a[lo] = (u + v) % mod
                a[hi] = (u - v) % mod
        half = m

    if invert:
        inv_n = pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * inv_n % mod
    return a


def multiply(a, b):
    """Return the product of two polynomials with coefficients modulo ``mod``.

    Args:
        a: Coefficient list, lowest degree first.
        b: Coefficient list, lowest degree first.

    Returns:
        A new list of ``len(a) + len(b) - 1`` coefficients in ``[0, mod)``,
        or ``[]`` when either input is empty. The inputs are not modified.
    """
    len_a = len(a)
    len_b = len(b)
    if len_a == 0 or len_b == 0:
        return []
    out_len = len_a + len_b - 1

    if min(len_a, len_b) <= _NAIVE_LIMIT:
        acc = [0] * out_len
        for i, x in enumerate(a):
            if x:
                for k, y in enumerate(b):
                    acc[i + k] += x * y
        return [v % mod for v in acc]

    n = 1
    while n < out_len:
        n <<= 1
    fa = a + [0] * (n - len_a)
    fb = b + [0] * (n - len_b)
    ntt(fa)
    ntt(fb)
    fa = [x * y % mod for x, y in zip(fa, fb)]
    ntt(fa, invert=True)  # the inverse pass leaves every value reduced
    del fa[out_len:]
    return fa


def product(polys):
    """Return the product of all polynomials in ``polys``.

    The two shortest polynomials are multiplied first, repeatedly, so that
    operand sizes stay balanced.

    Args:
        polys: Iterable of coefficient lists, lowest degree first.

    Returns:
        The product as a coefficient list; ``[1]`` when ``polys`` is empty.
    """
    if not polys:
        return [1]
    # The running index breaks ties so the heap never compares two lists.
    heap = [(len(p), idx, p) for idx, p in enumerate(polys)]
    heapq.heapify(heap)
    next_id = len(heap)
    while len(heap) > 1:
        first = heapq.heappop(heap)[2]
        second = heapq.heappop(heap)[2]
        merged = multiply(first, second)
        heapq.heappush(heap, (len(merged), next_id, merged))
        next_id += 1
    return heap[0][2]


def inverse(a, m):
    """Return the first ``m`` coefficients of the power series ``1 / a``.

    Uses the Newton step ``g <- g * (2 - a * g)``, doubling the number of
    correct coefficients each round.

    Args:
        a: Coefficient list, lowest degree first; ``a[0]`` must be invertible
            modulo ``mod`` for the result to be meaningful.
        m: Number of coefficients wanted.

    Returns:
        A list of ``m`` coefficients, or ``[]`` when ``m`` is not positive.
    """
    if m <= 0:
        return []
    g = [pow(a[0], mod - 2, mod)]
    n = 1
    while n < m:
        new_n = min(n * 2, m)
        a_trunc = a[:new_n]
        a_trunc += [0] * (new_n - len(a_trunc))
        fg = multiply(a_trunc, g)[:new_n]
        correction = [(-c) % mod for c in fg]
        correction[0] = (correction[0] + 2) % mod
        g = multiply(g, correction)[:new_n]
        n = new_n
    return g[:m]


def power_sums(a, m):
    """Return ``[sum(x ** k for x in a) % mod for k in 1 .. m]``.

    Args:
        a: List of integers.
        m: Highest exponent; a non-positive value yields ``[]``.

    Returns:
        A list of ``m`` values in ``[0, mod)``.
    """
    if m <= 0:
        return []
    if not a:
        return [0] * m

    # P(x) = prod(1 - a_i * x)
    P = product([[1, (-x) % mod] for x in a])

    # Coefficients of -P'(x), cut or zero-padded to exactly m terms.
    neg_dP = [(-i * c) % mod for i, c in enumerate(P)][1:m + 1]
    neg_dP += [0] * (m - len(neg_dP))

    # -P'(x) / P(x) = sum_{k >= 0} p_{k+1} x^k
    return multiply(neg_dP, inverse(P, m))[:m]


def main():
    """Read the problem input from stdin and print the ``m`` power sums."""
    data = sys.stdin.read().split()
    n, m = int(data[0]), int(data[1])
    a = [int(tok) for tok in data[2:2 + n]]
    print(' '.join(map(str, power_sums(a, m))))


if __name__ == "__main__":
    main()