結果
| 問題 | No.1145 Sums of Powers |
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-03-26 15:50:31 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 3,151 bytes |
| コンパイル時間 | 447 ms |
| コンパイル使用メモリ | 82,168 KB |
| 実行使用メモリ | 191,644 KB |
| 最終ジャッジ日時 | 2025-03-26 15:51:43 |
| 合計ジャッジ時間 | 9,801 ms |
| ジャッジサーバーID (参考情報) | judge5 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 3 TLE * 3 |
ソースコード
import sys
# NTT-friendly prime: MOD - 1 = 119 * 2**23, so power-of-two transform
# lengths up to 2**23 divide MOD - 1.
MOD = 998244353
# Primitive root modulo MOD used to build roots of unity for the NTT.
g = 3
# Modular inverse of g (Fermat), used for the inverse transform; == 332748118.
gi = pow(g, MOD - 2, MOD)
def readints():
    """Read one line from standard input and return its integers as a list."""
    line = sys.stdin.readline()
    return [int(token) for token in line.split()]
def mod_pow(a, b, mod):
    """Return ``a ** b % mod``.

    Delegates to the built-in three-argument ``pow``, which performs
    square-and-multiply in C and handles negative *a*. Unlike the former
    hand-rolled loop, it correctly returns 0 when ``mod == 1``. For a
    negative *b* it returns the corresponding power of the modular inverse
    (raising ``ValueError`` if *a* is not invertible); callers in this file
    only pass non-negative exponents.
    """
    return pow(a, b, mod)
def ntt(a, invert=False):
    """Transform *a* in place with an iterative radix-2 NTT modulo ``MOD``.

    Args:
        a: list of residues whose length is a power of two no larger than
            2**23 (so that it divides ``MOD - 1``).
        invert: when true, use powers of ``gi`` instead of ``g`` and scale
            every entry by ``1 / len(a)``, i.e. apply the inverse transform.

    Returns:
        The same list object *a*, now holding the transformed values.
    """
    mod = MOD  # local binding: this is the hot loop of the whole program
    n = len(a)
    half = n >> 1
    # Bit-reversal permutation of the input so the in-place butterflies
    # below can run stage by stage.
    rev = [0] * n
    for i in range(1, n):
        rev[i] = (rev[i >> 1] >> 1) | (half if i & 1 else 0)
        if i < rev[i]:
            a[i], a[rev[i]] = a[rev[i]], a[i]
    # roots[k] = root**k for the primitive n-th root of unity (or its inverse).
    root = mod_pow(gi if invert else g, (mod - 1) // n, mod)
    roots = [1] * half
    for i in range(1, half):
        roots[i] = roots[i - 1] * root % mod
    length = 2
    while length <= n:
        span = length >> 1
        # Twiddles for this stage are roots[j * (n // length)] for j < span;
        # gather them once per stage instead of recomputing the stride and
        # the index inside every butterfly.
        twiddles = roots[::n // length]
        for start in range(0, n, length):
            for j in range(span):
                lo = start + j
                hi = lo + span
                u = a[lo]
                v = a[hi] * twiddles[j] % mod
                a[lo] = (u + v) % mod
                a[hi] = (u - v) % mod
        length <<= 1
    if invert:
        inv_n = mod_pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * inv_n % mod
    return a
def multiply(a, b):
    """Return the product of polynomials *a* and *b* modulo ``MOD``.

    Coefficient lists are lowest degree first. The result has
    ``len(a) + len(b) - 1`` coefficients, each in ``[0, MOD)``; if either
    operand is empty the product is empty. Neither input list is modified.
    """
    la, lb = len(a), len(b)
    if la == 0 or lb == 0:
        return []
    out_len = la + lb - 1
    # Schoolbook multiplication is far cheaper than three NTTs when one side
    # is short (e.g. the degree-1 leaves of the product tree).
    if min(la, lb) <= 32:
        short, other = (a, b) if la <= lb else (b, a)
        res = [0] * out_len
        for j, y in enumerate(short):
            y %= MOD
            if y == 0:
                continue
            for i, x in enumerate(other):
                # Reduce every step so values stay within machine-word range.
                res[i + j] = (res[i + j] + x * y) % MOD
        return res
    size = 1
    while size < out_len:
        size <<= 1
    fa = ntt(a + [0] * (size - la))
    # Squaring (e.g. multiply(res, res)) needs only one forward transform.
    fb = fa if a is b else ntt(b + [0] * (size - lb))
    for i in range(size):
        fa[i] = fa[i] * fb[i] % MOD
    # The inverse transform already reduces every entry mod MOD.
    fa = ntt(fa, invert=True)
    del fa[out_len:]
    return fa
def product(polys, l, r, max_degree):
    """Multiply ``polys[l]`` through ``polys[r]`` (inclusive) by divide and conquer.

    Every merged partial product keeps at most ``max_degree + 1``
    coefficients. A one-element range returns that polynomial object
    itself, untruncated.
    """
    if l != r:
        mid = (l + r) // 2
        merged = multiply(product(polys, l, mid, max_degree),
                          product(polys, mid + 1, r, max_degree))
        limit = max_degree + 1
        return merged[:limit] if len(merged) > limit else merged
    return polys[l]
def inverse(a, m):
    """Return the first *m* coefficients of ``1 / a`` as a power series mod ``MOD``.

    Uses Newton iteration ``r <- 2r - a*r^2 (mod x^k)``; each round doubles
    the number of correct coefficients (capped at *m*).

    Args:
        a: series coefficients, lowest degree first.
        m: number of coefficients wanted; ``m <= 0`` yields ``[]``.

    Raises:
        ValueError: if the constant term of *a* is 0 modulo ``MOD``, since
            such a series has no multiplicative inverse.
    """
    if m <= 0:
        return []
    if not a or a[0] % MOD == 0:
        raise ValueError("constant term is not invertible mod MOD")
    res = [mod_pow(a[0], MOD - 2, MOD)]
    while len(res) < m:
        cur = len(res)
        next_len = min(cur << 1, m)
        # Only a mod x^next_len influences this round; pad if a is shorter.
        head = a[:next_len]
        head += [0] * (next_len - len(head))
        # r^2 mod x^next_len, then a * r^2 (only its first next_len terms used).
        sq = multiply(res, res)[:next_len]
        ar2 = multiply(sq, head)
        res = [((2 * res[i] if i < cur else 0) - ar2[i]) % MOD
               for i in range(next_len)]
    return res
def main():
    """Print ``sum(a**k for a in A) % MOD`` for k = 1..M, space separated.

    Input: ``N M`` on the first line and the ``N`` values on the second.
    With ``P(x) = prod(1 - a*x)``, the identity
    ``-P'(x) / P(x) = sum_{k>=0} (sum_a a**(k+1)) x**k`` gives all power
    sums, so only the coefficients of ``P`` up to ``x**(M+1)`` are needed.
    """
    N, M = readints()
    A = readints()
    if N == 0:
        # Empty sum: every power sum is 0 (no trailing separator).
        print(' '.join(['0'] * M))
        return
    polys = [[1, (MOD - a) % MOD] for a in A]
    P = product(polys, 0, len(polys) - 1, M + 1)
    # P'(x) mod x^(M+1); i <= M + 1 here, so index i - 1 always fits.
    P_deriv = [0] * (M + 1)
    for i in range(1, min(len(P), M + 2)):
        P_deriv[i - 1] = i * P[i] % MOD
    # Q = P'/P; its length is 2M + 1 >= M, so Q[k] exists for every k < M.
    Q = multiply(P_deriv, inverse(P, M + 1))
    print(' '.join(str(-Q[k] % MOD) for k in range(M)))
# Run the solver only when this file is executed as a script.
if __name__ == "__main__":
    main()
lam6er