結果

問題 No.1145 Sums of Powers
ユーザー lam6er
提出日時 2025-04-16 00:53:10
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 3,160 bytes
コンパイル時間 366 ms
コンパイル使用メモリ 81,720 KB
実行使用メモリ 160,464 KB
最終ジャッジ日時 2025-04-16 00:55:28
合計ジャッジ時間 8,646 ms
ジャッジサーバーID (参考情報): judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 3 TLE * 3
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
MOD = 998_244_353  # NTT-friendly prime: 119 * 2**23 + 1
ROOT = 3  # primitive root modulo MOD (generator of the multiplicative group)

def ntt(a, invert=False):
    """In-place iterative radix-2 number-theoretic transform modulo MOD.

    Computes A[k] = sum_j a[j] * w**(j*k) (mod MOD) for k in [0, n), where
    w = ROOT**((MOD - 1) // n). With ``invert=True`` the inverse root is used
    and the result is scaled by n^{-1}, so ``ntt(ntt(a), invert=True)``
    restores ``a``.

    Args:
        a: list of ints; its length n must be a power of two (and divide
            MOD - 1). The list is modified in place.
        invert: compute the inverse transform instead of the forward one.

    Returns:
        The same list object ``a``, now holding the transform.

    Raises:
        ValueError: if ``len(a)`` is not a power of two (0 is accepted).
    """
    n = len(a)
    if n & (n - 1):
        raise ValueError("ntt length must be a power of two")

    # Bit-reversal permutation so the butterflies below can run in place.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    mod = MOD  # local binding: this is the hot loop of the whole program
    # (ROOT^-1)^e == (ROOT^e)^-1, so inverting the root once is equivalent to
    # inverting every stage root.
    root = pow(ROOT, mod - 2, mod) if invert else ROOT
    half = 1
    while half < n:
        m = half << 1
        w_m = pow(root, (mod - 1) // m, mod)
        # Twiddles w_m**0 .. w_m**(half-1) are shared by every block of this
        # stage; build them once instead of re-deriving them per block.
        twiddles = [1] * half
        for t in range(1, half):
            twiddles[t] = twiddles[t - 1] * w_m % mod
        for k in range(0, n, m):
            for t in range(half):
                lo = k + t
                hi = lo + half
                u = a[lo]
                v = a[hi] * twiddles[t] % mod
                a[lo] = (u + v) % mod
                a[hi] = (u - v) % mod
        half = m

    if invert:
        inv_n = pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * inv_n % mod
    return a

def convolution(a, b):
    """Return the product of polynomials ``a`` and ``b`` modulo MOD.

    Coefficients are listed lowest degree first. The result has
    ``len(a) + len(b) - 1`` coefficients, each in [0, MOD), or ``[]`` when
    either operand is empty. The input sequences are NOT modified (the
    previous version extended them in place and overwrote them with their
    transforms).

    When one operand is short, schoolbook multiplication is used: near the
    leaves of a product tree it is far cheaper than three NTTs.
    """
    naive_threshold = 16
    len_a = len(a)
    len_b = len(b)
    if not len_a or not len_b:
        return []
    out_len = len_a + len_b - 1

    if min(len_a, len_b) <= naive_threshold:
        res = [0] * out_len
        for i, x in enumerate(a):
            if x:
                for j, y in enumerate(b):
                    res[i + j] += x * y
        return [v % MOD for v in res]

    # Smallest power of two >= out_len.
    size = 1 << (out_len - 1).bit_length()
    fa = ntt(list(a) + [0] * (size - len_a))
    fb = ntt(list(b) + [0] * (size - len_b))
    prod = [x * y % MOD for x, y in zip(fa, fb)]
    return ntt(prod, invert=True)[:out_len]

def multiply(a, b):
    """Multiply two polynomials modulo MOD; thin alias for ``convolution``."""
    product = convolution(a, b)
    return product

def product_helper(a_list, l, r):
    """Return prod_{i=l}^{r} (1 - a_list[i] * x) modulo MOD.

    Divides the inclusive index range [l, r] in half and multiplies the two
    sub-products, so multiplications pair polynomials of similar degree.
    Coefficients are listed lowest degree first.
    """
    if l != r:
        split = (l + r) >> 1
        return multiply(product_helper(a_list, l, split),
                        product_helper(a_list, split + 1, r))
    # Leaf: the linear factor 1 - a*x.
    return [1, -a_list[l] % MOD]

def compute_P(A):
    """Return prod(1 - a*x for a in A) modulo MOD, lowest degree first.

    The empty product is the constant polynomial ``[1]``.
    """
    if A:
        return product_helper(A, 0, len(A) - 1)
    return [1]

def compute_derivative(P):
    """Return the formal derivative of polynomial ``P`` modulo MOD.

    ``P[i]`` is the coefficient of x**i. The result has ``len(P) - 1``
    coefficients, or is ``[]`` for an empty or constant ``P``.
    """
    return [k * coef % MOD for k, coef in enumerate(P[1:], 1)]

def inverse(P, m):
    """Return the first ``m`` coefficients of the power series 1 / P(x) mod MOD.

    Uses Newton iteration g <- g * (2 - P * g) (mod x^k), which doubles the
    number of correct coefficients each round.

    Args:
        P: list of polynomial coefficients, lowest degree first.
        m: number of coefficients wanted.

    Returns:
        A list of ``m`` coefficients, or ``[]`` when ``P`` is empty or
        ``m <= 0``.

    Raises:
        ValueError: if P[0] is divisible by MOD (no power-series inverse).
    """
    if not P or m <= 0:
        return []
    if P[0] % MOD == 0:
        raise ValueError("constant term of P is not invertible modulo MOD")

    g = [0] * m
    g[0] = pow(P[0], MOD - 2, MOD)
    cur = 1
    while cur < m:
        nxt = min(cur * 2, m)
        # P mod x^nxt, zero-padded when P is shorter than nxt.
        f = P[:nxt]
        f += [0] * (nxt - len(f))
        # e = 2 - P*g (mod x^nxt). Fresh slices of g are passed to each
        # multiply call so that neither call can disturb the other's input.
        e = multiply(f, g[:cur])[:nxt]
        e = [-c % MOD for c in e]
        e[0] = (e[0] + 2) % MOD
        refined = multiply(g[:cur], e)
        # The first `cur` coefficients are already correct; fill the rest.
        g[cur:nxt] = refined[cur:nxt]
        cur = nxt
    return g

def main():
    """Read N, M and A_1..A_N from stdin; print sum_i A_i**k mod MOD, k=1..M.

    With P(x) = prod_i (1 - A_i x), -P'(x)/P(x) = sum_{k>=0} p_{k+1} x**k,
    where p_j = sum_i A_i**j. The first M coefficients are therefore the
    answers, printed space-separated on one line.
    """
    data = sys.stdin.read().split()
    n = int(data[0])
    m = int(data[1])
    a = [int(tok) for tok in data[2:2 + n]]

    # No terms (or nothing requested): every power sum is zero.
    if n == 0 or m <= 0:
        print(' '.join(['0'] * m))
        return

    p = compute_P(a)
    # Only coefficients below x^m affect the truncated product, so
    # truncation is enough and no zero padding is needed.
    p_deriv = compute_derivative(p)[:m]
    q = inverse(p, m)
    r = multiply(p_deriv, q)[:m]
    print(' '.join(str(-x % MOD) for x in r))

# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
0