結果
| 問題 | No.1145 Sums of Powers |
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-04-16 00:48:24 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 3,160 bytes |
| コンパイル時間 | 392 ms |
| コンパイル使用メモリ | 82,128 KB |
| 実行使用メモリ | 160,916 KB |
| 最終ジャッジ日時 | 2025-04-16 00:51:26 |
| 合計ジャッジ時間 | 4,336 ms |
| ジャッジサーバーID (参考情報) | judge5 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 3 TLE * 3 |
ソースコード
import sys
# NTT-friendly prime: 998244353 = 119 * 2**23 + 1, so number-theoretic
# transforms of any power-of-two length up to 2**23 are available.
MOD = 998_244_353
ROOT = 3  # primitive root modulo MOD
def ntt(a, invert=False):
    """Apply an in-place iterative radix-2 number-theoretic transform.

    The forward transform evaluates the polynomial whose coefficients are
    ``a`` at the powers of ``ROOT ** ((MOD - 1) // len(a))`` modulo ``MOD``
    (bit-reversed input order, natural output order).  With ``invert=True``
    the inverse transform is applied, including the final scaling by
    ``1 / len(a)``.

    Args:
        a: List of coefficients modulo ``MOD``.  ``len(a)`` must be a power
            of two dividing ``MOD - 1`` (at most 2**23 for 998244353).  The
            list is overwritten with the result.
        invert: Compute the inverse transform instead of the forward one.

    Returns:
        The same list object ``a``, now holding the transformed values.
    """
    mod = MOD
    n = len(a)

    # Permute the input into bit-reversed index order.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j |= bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    length = 2
    while length <= n:
        half = length >> 1
        w_len = pow(ROOT, (mod - 1) // length, mod)
        if invert:
            w_len = pow(w_len, mod - 2, mod)
        # Build this stage's twiddle factors once, instead of re-deriving
        # them with an extra modular multiply inside every butterfly.
        twiddles = [1] * half
        for k in range(1, half):
            twiddles[k] = twiddles[k - 1] * w_len % mod
        for start in range(0, n, length):
            for k in range(half):
                lo = start + k
                hi = lo + half
                t = a[hi] * twiddles[k] % mod
                u = a[lo]
                a[lo] = (u + t) % mod
                a[hi] = (u - t) % mod
        length <<= 1

    if invert:
        inv_n = pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * inv_n % mod
    return a
def convolution(a, b):
    """Return the product of two polynomials with coefficients modulo ``MOD``.

    Args:
        a: Coefficient list, lowest degree first.  Not modified.
        b: Coefficient list, lowest degree first.  Not modified.

    Returns:
        A new list of ``len(a) + len(b) - 1`` coefficients reduced modulo
        ``MOD``, or ``[]`` if either input is empty.
    """
    len_a = len(a)
    len_b = len(b)
    if not len_a or not len_b:
        return []
    out_len = len_a + len_b - 1

    # Schoolbook multiplication is far cheaper than three NTTs when one side
    # is short, which covers the many tiny products near the leaves of the
    # product tree.
    if min(len_a, len_b) <= 32:
        if len_a < len_b:
            a, b = b, a
        res = [0] * out_len
        for j, bj in enumerate(b):
            if bj:
                for i, ai in enumerate(a):
                    k = i + j
                    res[k] = (res[k] + ai * bj) % MOD
        return res

    n = 1
    while n < out_len:
        n <<= 1
    # Pad copies so the caller's lists are left untouched.
    fa = ntt(list(a) + [0] * (n - len_a))
    fb = ntt(list(b) + [0] * (n - len_b))
    c = ntt([x * y % MOD for x, y in zip(fa, fb)], invert=True)
    return c[:out_len]
def multiply(a, b):
    """Multiply two coefficient lists modulo MOD (thin alias for convolution)."""
    product = convolution(a, b)
    return product
def product_helper(a_list, l, r):
    """Return prod_{i=l..r} (1 - a_list[i] * x) as a coefficient list mod MOD.

    The index range is halved recursively, so factors are combined in a
    balanced binary tree of multiplications.
    """
    if l != r:
        split = (l + r) >> 1
        return multiply(
            product_helper(a_list, l, split),
            product_helper(a_list, split + 1, r),
        )
    # A single factor: 1 - a * x.
    return [1, (-a_list[l]) % MOD]
def compute_P(A):
    """Return the coefficients of prod_{a in A} (1 - a * x) modulo MOD.

    An empty ``A`` gives the constant polynomial ``[1]``.
    """
    return product_helper(A, 0, len(A) - 1) if A else [1]
def compute_derivative(P):
    """Return the formal derivative of polynomial ``P`` reduced modulo MOD.

    ``P[i]`` is the coefficient of x**i.  The result has ``len(P) - 1``
    entries, and is empty for an empty or constant ``P``.
    """
    return [k * coef % MOD for k, coef in enumerate(P[1:], start=1)]
def inverse(P, m):
    """Return the first ``m`` coefficients of the power series ``1 / P(x)``.

    Uses Newton iteration g <- g * (2 - P * g) (mod x**k), which doubles the
    number of correct coefficients each round.  All arithmetic is mod MOD.

    Args:
        P: Coefficient list of the series, lowest degree first.
        m: Number of output coefficients wanted.

    Returns:
        A list of ``m`` coefficients, or ``[]`` if ``P`` is empty or
        ``m <= 0``.

    Raises:
        ValueError: If ``P[0]`` is divisible by MOD, so no inverse exists.
    """
    if not P or m <= 0:
        return []
    if P[0] % MOD == 0:
        raise ValueError("constant term is not invertible modulo MOD")

    g = [0] * m
    g[0] = pow(P[0], MOD - 2, MOD)
    current_len = 1
    while current_len < m:
        next_len = min(current_len * 2, m)
        # P truncated (or zero-padded) to next_len coefficients.
        f = P[:next_len]
        f += [0] * (next_len - len(f))
        # two_minus = 2 - P * g  (mod x**next_len)
        product = multiply(f, g[:current_len])[:next_len]
        two_minus = [(-x) % MOD for x in product]
        two_minus[0] = (two_minus[0] + 2) % MOD
        # g * (2 - P * g) agrees with g on the low current_len terms, so only
        # the newly determined coefficients are copied in.
        increment = multiply(g[:current_len], two_minus)
        g[current_len:next_len] = increment[current_len:next_len]
        current_len = next_len
    return g
def main():
    """Read N, M and A_1..A_N from stdin; print p_1..p_M on one line.

    p_k = sum_i A_i**k (mod MOD).  With P(x) = prod_i (1 - A_i x),
    P'(x) / P(x) = -sum_{k>=0} p_{k+1} x**k, so the answer is the first M
    coefficients of -P'(x) * (1 / P(x)).
    """
    tokens = sys.stdin.read().split()
    N = int(tokens[0])
    M = int(tokens[1])
    A = [int(tok) for tok in tokens[2:2 + N]]
    if N == 0:
        print(' '.join(['0'] * M))
        return

    P = compute_P(A)
    # Only the first M coefficients of P' can influence the first M of R.
    P_deriv = compute_derivative(P)[:M]
    Q = inverse(P, M)
    R = multiply(P_deriv, Q)[:M]
    # Pad with zeros in case the product is shorter than M (e.g. empty P').
    R += [0] * (M - len(R))
    print(' '.join(str((-x) % MOD) for x in R))
# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
lam6er