結果
| 問題 | No.1145 Sums of Powers |
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-04-16 00:44:59 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 3,335 bytes |
| コンパイル時間 | 227 ms |
| コンパイル使用メモリ | 82,284 KB |
| 実行使用メモリ | 161,380 KB |
| 最終ジャッジ日時 | 2025-04-16 00:48:11 |
| 合計ジャッジ時間 | 4,098 ms |
| ジャッジサーバーID (参考情報) | judge2 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 3 TLE * 1 -- * 2 |
ソースコード
import sys
# NTT-friendly prime: 998244353 == 119 * 2**23 + 1, so roots of unity of
# every order 2**k with k <= 23 exist modulo it.
mod = 119 * (1 << 23) + 1
root = 3  # primitive root modulo `mod`
def ntt(a, invert=False):
    """Apply an iterative radix-2 number-theoretic transform to ``a`` in place.

    Args:
        a: list of integers; its length is expected to be a power of two
            (``multiply`` pads its operands to one).
        invert: when True, run the inverse transform, including the final
            multiplication by ``len(a)**-1`` modulo ``mod``.

    Returns:
        The same list object ``a``, holding the transformed values.
    """
    n = len(a)
    # Bit-reversal permutation so the butterflies below can work in place.
    j = 0
    for i in range(1, n - 1):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    length = 2
    while length <= n:
        half = length >> 1
        w_m = pow(root, (mod - 1) // length, mod)
        if invert:
            w_m = pow(w_m, mod - 2, mod)
        # The twiddle factors w_m**k depend only on the stage, so build them
        # once here instead of re-deriving them inside every block.
        twiddles = [1] * half
        for k in range(1, half):
            twiddles[k] = twiddles[k - 1] * w_m % mod
        for start in range(0, n, length):
            for k in range(half):
                lo = start + k
                hi = lo + half
                u = a[lo]
                v = a[hi] * twiddles[k] % mod
                a[lo] = (u + v) % mod
                a[hi] = (u - v) % mod
        length <<= 1
    if invert:
        inv_n = pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * inv_n % mod
    return a
def multiply(a, b):
    """Return the product of polynomials ``a`` and ``b`` modulo ``mod``.

    Coefficient lists are lowest degree first. The result has
    ``len(a) + len(b) - 1`` coefficients, each reduced into ``[0, mod)``.
    An empty operand yields ``[]``. The inputs are not modified.
    """
    len_a = len(a)
    len_b = len(b)
    if len_a == 0 or len_b == 0:
        return []
    res_len = len_a + len_b - 1
    # For a short operand, schoolbook multiplication is much cheaper than
    # three padded NTTs. The product tree and the first Newton rounds issue
    # many such small products.
    if min(len_a, len_b) <= 32:
        short, long_ = (a, b) if len_a <= len_b else (b, a)
        res = [0] * res_len
        for i, x in enumerate(short):
            x %= mod
            for j, y in enumerate(long_):
                # Reduce every step so intermediates stay machine-word sized.
                res[i + j] = (res[i + j] + x * y) % mod
        return res
    n = 1
    while n < res_len:
        n <<= 1
    fa = ntt(a + [0] * (n - len_a))
    fb = ntt(b + [0] * (n - len_b))
    prod = [x * y % mod for x, y in zip(fa, fb)]
    # The inverse transform already reduces every entry modulo `mod`.
    return ntt(prod, invert=True)[:res_len]
def product(polys):
    """Return the product of every polynomial in ``polys`` (modulo ``mod``).

    An empty sequence yields the constant polynomial ``[1]``. A single
    polynomial is returned as-is. Otherwise the two currently shortest
    polynomials are multiplied repeatedly (size-greedy merging), so large
    multiplications happen only near the end.
    """
    import heapq
    if not polys:
        return [1]
    # Entries are (length, serial, poly). The unique serial keeps heapq from
    # ever comparing two polynomial lists when lengths tie.
    heap = [(len(p), i, p) for i, p in enumerate(polys)]
    heapq.heapify(heap)
    serial = len(heap)
    while len(heap) > 1:
        _, _, p = heapq.heappop(heap)
        _, _, q = heapq.heappop(heap)
        merged = multiply(p, q)
        heapq.heappush(heap, (len(merged), serial, merged))
        serial += 1
    return heap[0][2]
def inverse(a, m):
    """Return the first ``m`` coefficients of the power series ``1 / a`` mod ``mod``.

    Uses Newton iteration ``g <- g * (2 - a * g)``, which doubles the number
    of correct coefficients each round. Coefficients of ``a`` beyond its
    length are treated as zero.

    Raises:
        ValueError: if ``a`` is empty or its constant term is divisible by
            ``mod``, so that no inverse series exists.
    """
    if m == 0:
        return []
    if not a or a[0] % mod == 0:
        raise ValueError("power series with zero constant term has no inverse")
    g = [pow(a[0], mod - 2, mod)]
    n = 1
    while n < m:
        new_n = min(2 * n, m)
        head = a[:new_n]
        head += [0] * (new_n - len(head))
        # correction = 2 - a*g (mod x^new_n); g * correction has new_n
        # correct terms.
        ag = multiply(head, g)[:new_n]
        correction = [(-c) % mod for c in ag]
        correction[0] = (correction[0] + 2) % mod
        g = multiply(g, correction)[:new_n]
        n = new_n
    return g[:m]
def main():
    """Read N, M and A_1..A_N from stdin; print sum_i A_i^k mod `mod` for k = 1..M.

    With P(x) = prod_i (1 - A_i x), the identity
    -P'(x) / P(x) = sum_{k>=0} (sum_i A_i^(k+1)) x^k
    gives all M power sums from one series inversion and one multiplication.
    """
    # Token-based parsing: does not depend on how the values are split
    # across lines.
    data = sys.stdin.buffer.read().split()
    n, m = int(data[0]), int(data[1])
    a = [int(tok) for tok in data[2:2 + n]]
    if n == 0:
        print(' '.join(['0'] * m))
        return
    poly = product([[1, (-ai) % mod] for ai in a])
    # P'(x), truncated or zero-padded to exactly m coefficients.
    dpoly = [i * poly[i] % mod for i in range(1, len(poly))][:m]
    dpoly += [0] * (m - len(dpoly))
    neg_dpoly = [(-c) % mod for c in dpoly]
    sums = multiply(neg_dpoly, inverse(poly, m))[:m]
    sums += [0] * (m - len(sums))
    print(' '.join(str(x % mod) for x in sums))
if __name__ == '__main__':
    # Run the solver only when executed as a script, not when imported.
    main()
lam6er