結果

問題 No.1145 Sums of Powers
ユーザー lam6er
提出日時 2025-04-16 00:48:21
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,335 bytes
コンパイル時間 174 ms
コンパイル使用メモリ 81,568 KB
実行使用メモリ 199,836 KB
最終ジャッジ日時 2025-04-16 00:51:20
合計ジャッジ時間 3,977 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 3 TLE * 1 -- * 2
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
# NTT-friendly prime: 998244353 = 119 * 2**23 + 1, so power-of-two transform
# lengths up to 2**23 have the required roots of unity modulo this prime.
mod = 998244353
root = 3  # Primitive root for mod 998244353

def ntt(a, invert=False):
    """Apply an in-place iterative radix-2 number-theoretic transform modulo ``mod``.

    Args:
        a: list of ints whose length is a power of two, no larger than 2**23
            (``mod - 1`` is divisible by 2**23 and no higher power of two).
            The list is permuted and overwritten in place.
        invert: when True, compute the inverse transform. It uses the inverse
            root of unity and scales every entry by ``1 / len(a)``.

    Returns:
        The same list object ``a``. For lengths >= 2 every entry ends up in
        ``[0, mod)``.
    """
    n = len(a)
    # Bit-reversal permutation so the butterflies below can work in place.
    j = 0
    for i in range(1, n - 1):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    length = 2
    while length <= n:
        half = length >> 1
        w_len = pow(root, (mod - 1) // length, mod)
        if invert:
            w_len = pow(w_len, mod - 2, mod)
        # Iterate the twiddle index outermost: each twiddle factor is then
        # computed once per stage instead of once per butterfly.
        w = 1
        for offset in range(half):
            for lo in range(offset, n, length):
                hi = lo + half
                u = a[lo]
                t = w * a[hi] % mod
                a[lo] = (u + t) % mod
                a[hi] = (u - t) % mod
            w = w * w_len % mod
        length <<= 1
    if invert:
        inv_n = pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * inv_n % mod
    return a

def multiply(a, b):
    """Return the coefficients of a(x) * b(x), each reduced modulo ``mod``.

    The inputs are not modified. If either input is empty, ``[]`` is
    returned. Otherwise the result has ``len(a) + len(b) - 1`` entries, all in
    ``[0, mod)``.

    Short operands use a schoolbook product. Three NTTs cost far more than
    the direct loop there, and ``product`` issues many such small calls.
    Longer operands go through ``ntt``.
    """
    len_a = len(a)
    len_b = len(b)
    if len_a == 0 or len_b == 0:
        return []
    out_len = len_a + len_b - 1
    naive_limit = 32
    if min(len_a, len_b) <= naive_limit:
        res = [0] * out_len
        for i, x in enumerate(a):
            if x == 0:
                continue
            for j, y in enumerate(b):
                # Reduce every step so intermediates stay machine-word sized.
                res[i + j] = (res[i + j] + x * y) % mod
        return res
    # Smallest power of two that can hold the full (non-cyclic) product.
    size = 1 << (out_len - 1).bit_length()
    fa = ntt(a + [0] * (size - len_a))
    fb = ntt(b + [0] * (size - len_b))
    for i in range(size):
        fa[i] = fa[i] * fb[i] % mod
    ntt(fa, invert=True)  # in place; entries already reduced by the 1/n scaling
    del fa[out_len:]
    return fa

def product(polys):
    """Return the product of the given polynomials (coefficient lists) modulo ``mod``.

    The two shortest pending polynomials are always merged first, using a
    min-heap keyed on length. Ties are broken by insertion order, so inputs
    of equal length combine in a balanced tree.

    Returns ``[1]`` for an empty input. With a single polynomial, that same
    list object is returned unchanged.
    """
    if not polys:
        return [1]
    import heapq
    # The integer tiebreak keeps heap comparisons from ever reaching the lists.
    heap = [(len(p), idx, p) for idx, p in enumerate(polys)]
    heapq.heapify(heap)
    next_id = len(heap)
    while len(heap) > 1:
        _, _, left = heapq.heappop(heap)
        _, _, right = heapq.heappop(heap)
        merged = multiply(left, right)
        heapq.heappush(heap, (len(merged), next_id, merged))
        next_id += 1
    return heap[0][2]

def inverse(a, m):
    """Return the first ``m`` coefficients of the power series ``1 / a(x)`` modulo ``mod``.

    The result is computed by Newton iteration, ``g <- g * (2 - a*g)``, which
    doubles the number of correct coefficients each round.

    Args:
        a: coefficient list of the series. Missing high terms are treated as 0.
        m: number of coefficients wanted. ``m == 0`` returns ``[]``.

    Raises:
        ValueError: if ``m > 0`` and ``a`` is empty, or its constant term is
            divisible by ``mod``. No inverse series exists in either case.
    """
    if m == 0:
        return []
    if not a or a[0] % mod == 0:
        raise ValueError("power series inverse requires a constant term invertible modulo mod")
    g = [pow(a[0], mod - 2, mod)]
    n = 1
    while n < m:
        new_n = min(n * 2, m)
        a_trunc = a[:new_n]
        a_trunc += [0] * (new_n - len(a_trunc))
        fg = multiply(a_trunc, g)[:new_n]
        # correction = 2 - a*g  (mod x**new_n)
        correction = [(-c) % mod for c in fg]
        correction[0] = (correction[0] + 2) % mod
        g = multiply(g, correction)[:new_n]
        n = new_n
    return g[:m]

def _power_sums(a, m):
    """Return ``[sum(x**k for x in a) % mod for k in 1..m]``.

    With ``P(x) = prod(1 - x_i * x)``, the series ``-P'(x) / P(x)`` equals
    ``sum_{k>=0} (sum_i x_i**(k+1)) * x**k``. Its first ``m`` coefficients are
    therefore the requested power sums.
    """
    if not a:
        return [0] * m
    P = product([[1, (-x) % mod] for x in a])
    # Terms of degree > m cannot influence the first m coefficients of -P'/P.
    del P[m + 1:]
    neg_deriv = [(-i * P[i]) % mod for i in range(1, len(P))]
    neg_deriv += [0] * (m - len(neg_deriv))
    sums = multiply(neg_deriv, inverse(P, m))[:m]
    sums += [0] * (m - len(sums))
    return sums


def main():
    """Read N, M and the N values of A from stdin, then print the power sums.

    The output is one space-separated line with ``sum(a_i**k) % mod`` for
    k = 1..M.
    """
    # Read every token at once: faster, and independent of how A is split across lines.
    data = sys.stdin.buffer.read().split()
    n, m = int(data[0]), int(data[1])
    a = [int(tok) for tok in data[2:2 + n]]
    print(' '.join(map(str, _power_sums(a, m))))

if __name__ == "__main__":
    main()
0