結果
問題 |
No.1145 Sums of Powers
|
ユーザー |
![]() |
提出日時 | 2025-03-26 15:50:31 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 3,151 bytes |
コンパイル時間 | 447 ms |
コンパイル使用メモリ | 82,168 KB |
実行使用メモリ | 191,644 KB |
最終ジャッジ日時 | 2025-03-26 15:51:43 |
合計ジャッジ時間 | 9,801 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 3 TLE * 3 |
ソースコード
"""Power sums of a multiset modulo 998244353 (yukicoder No.1145).

For A_1..A_N let P(x) = prod_i (1 - A_i x).  Then

    sum_{k>=1} (sum_i A_i^k) x^k = -x * P'(x) / P(x),

so the k-th power sum is minus the coefficient of x^(k-1) in P'/P.
P is built with a divide-and-conquer product, 1/P with Newton iteration,
and every large polynomial product uses a number-theoretic transform.
"""
import sys
from functools import lru_cache

MOD = 998244353
g = 3  # primitive root of MOD
gi = 332748118  # modular inverse of g

# Operands at most this long are multiplied with the schoolbook method;
# for them the fixed cost of three transforms outweighs the O(n log n) gain.
_NAIVE_LIMIT = 32


def readints():
    """Read one line from standard input and return its integers as a list."""
    return list(map(int, sys.stdin.readline().split()))


def mod_pow(a, b, mod):
    """Return ``a ** b`` reduced modulo ``mod``.

    A non-positive exponent yields 1, exactly as the previous hand-written
    square-and-multiply loop did; positive exponents use the built-in
    three-argument ``pow``.
    """
    if b <= 0:
        return 1
    return pow(a, b, mod)


@lru_cache(maxsize=None)
def _bit_reversal(n):
    """Return the bit-reversal permutation of ``range(n)`` (n a power of two)."""
    rev = [0] * n
    top_bit = n >> 1
    for i in range(1, n):
        rev[i] = (rev[i >> 1] >> 1) | (top_bit if i & 1 else 0)
    return rev


@lru_cache(maxsize=None)
def _root_powers(n, invert):
    """Return ``[w**0, ..., w**(n/2 - 1)]`` for a primitive n-th root ``w``.

    ``w`` is derived from ``g`` for the forward transform and from ``gi``
    for the inverse transform.
    """
    root = mod_pow(gi if invert else g, (MOD - 1) // n, MOD)
    powers = [1] * (n >> 1)
    for i in range(1, n >> 1):
        powers[i] = powers[i - 1] * root % MOD
    return powers


def ntt(a, invert=False):
    """Transform ``a`` in place with a number-theoretic transform modulo MOD.

    Args:
        a: list of integers whose length is a power of two (or zero).
        invert: when true, apply the inverse transform, including the
            division by ``len(a)``.

    Returns:
        The same list object ``a``, now holding the transformed values.

    Raises:
        ValueError: if ``len(a)`` is not a power of two.
    """
    n = len(a)
    if n == 0:
        return a
    if n & (n - 1):
        raise ValueError("ntt length must be a power of two")

    # Reorder into bit-reversed order so the butterflies can run in place.
    for i, j in enumerate(_bit_reversal(n)):
        if i < j:
            a[i], a[j] = a[j], a[i]

    roots = _root_powers(n, bool(invert))
    half = 1
    while half < n:
        # roots[k * stride] is the k-th power of a primitive (2*half)-th root.
        stride = n // (2 * half)
        for start in range(0, n, 2 * half):
            k = 0
            for lo in range(start, start + half):
                hi = lo + half
                u = a[lo]
                v = a[hi] * roots[k] % MOD
                a[lo] = (u + v) % MOD
                a[hi] = (u - v) % MOD
                k += stride
        half <<= 1

    if invert:
        inv_n = mod_pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * inv_n % MOD
    return a


def _multiply_naive(a, b):
    """Schoolbook product of two non-empty coefficient lists modulo MOD."""
    res = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        if x:
            for j, y in enumerate(b):
                res[i + j] = (res[i + j] + x * y) % MOD
    return res


def multiply(a, b):
    """Return the product of polynomials ``a`` and ``b`` modulo MOD.

    Both arguments are coefficient lists (lowest degree first) and are left
    unmodified.  The result has ``len(a) + len(b) - 1`` coefficients, or is
    empty when either input is empty.
    """
    if not a or not b:
        return []
    if min(len(a), len(b)) <= _NAIVE_LIMIT:
        return _multiply_naive(a, b)

    out_len = len(a) + len(b) - 1
    size = 1 << (out_len - 1).bit_length()  # smallest power of two >= out_len
    fa = ntt(a + [0] * (size - len(a)))
    fb = ntt(b + [0] * (size - len(b)))
    for i in range(size):
        fa[i] = fa[i] * fb[i] % MOD
    ntt(fa, invert=True)
    del fa[out_len:]
    return fa


def product(polys, l, r, max_degree):
    """Return the product of ``polys[l..r]`` truncated after ``x**max_degree``.

    The range is inclusive and is multiplied by divide and conquer.  An empty
    range (``l > r``) yields the empty product ``[1]``.  A single-element
    range returns ``polys[l]`` itself, without copying or truncating it.
    """
    if l > r:
        return [1]
    if l == r:
        return polys[l]
    mid = (l + r) // 2
    res = multiply(
        product(polys, l, mid, max_degree),
        product(polys, mid + 1, r, max_degree),
    )
    del res[max_degree + 1:]
    return res


def inverse(a, m):
    """Return the first ``m`` coefficients of the power series ``1 / a``.

    Uses Newton iteration; each doubling step needs five transforms of the
    new length (two cyclic convolutions sharing one transform of the current
    approximation).

    Args:
        a: coefficient list, lowest degree first; left unmodified.
        m: number of coefficients wanted; ``m <= 0`` yields ``[]``.

    Raises:
        ValueError: if the constant term of ``a`` is zero modulo MOD, since
            no inverse series exists then.
    """
    if m <= 0:
        return []
    const = a[0] % MOD if a else 0
    if const == 0:
        raise ValueError("constant term must be invertible modulo MOD")

    res = [mod_pow(const, MOD - 2, MOD)]
    length = 1
    while length < m:
        size = 2 * length
        f = a[:size]
        f += [0] * (size - len(f))
        r = res + [0] * length
        ntt(f)
        ntt(r)
        for i in range(size):
            f[i] = f[i] * r[i] % MOD
        ntt(f, invert=True)
        # f is now a*res reduced modulo x^size - 1.  The wrap-around only
        # touches the low half, and the true low half is known to be
        # 1, 0, ..., 0, so only the high half (the error term) is kept.
        for i in range(length):
            f[i] = 0
        ntt(f)
        for i in range(size):
            f[i] = f[i] * r[i] % MOD
        ntt(f, invert=True)
        # The high half now holds error * res; subtracting it extends res.
        res += [(-x) % MOD for x in f[length:]]
        length = size
    del res[m:]
    return res


def power_sums(A, M):
    """Return ``[sum(a**k for a in A) % MOD for k in 1..M]``.

    Args:
        A: list of integers.
        M: number of power sums wanted; ``M <= 0`` yields ``[]``.
    """
    if M <= 0:
        return []
    if not A:
        return [0] * M

    factors = [[1, (-a) % MOD] for a in A]
    # Coefficients of P up to x^M are enough: P' is needed modulo x^M.
    P = product(factors, 0, len(factors) - 1, M)
    deriv = [i * P[i] % MOD for i in range(1, min(len(P), M + 1))]
    quotient = multiply(deriv, inverse(P, M))
    return [(-q) % MOD for q in quotient[:M]]


def main():
    """Read ``N M`` and the N values from stdin, print the M power sums."""
    N, M = readints()
    A = readints()
    sums = [0] * M if N == 0 else power_sums(A, M)
    print(' '.join(map(str, sums)))


if __name__ == "__main__":
    main()