結果

問題 No.2005 Sum of Power Sums
ユーザー lam6er
提出日時 2025-03-20 20:53:41
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 1,926 bytes
コンパイル時間 205 ms
コンパイル使用メモリ 82,860 KB
実行使用メモリ 288,200 KB
最終ジャッジ日時 2025-03-20 20:54:46
合計ジャッジ時間 9,947 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 12 TLE * 1 -- * 5
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
MOD = 998244353

def main():
    input = sys.stdin.read().split()
    ptr = 0
    N, M = int(input[ptr]), int(input[ptr+1])
    ptr += 2
    K_list = list(map(int, input[ptr:ptr+N]))
    ptr += N

    # Precompute Stirling numbers of the second kind up to 5000
    max_d_stirling = 5000
    stirling = [[0] * (max_d_stirling + 1) for _ in range(max_d_stirling + 1)]
    stirling[0][0] = 1
    for d in range(1, max_d_stirling + 1):
        for k in range(1, d + 1):
            stirling[d][k] = (k * stirling[d - 1][k] + stirling[d - 1][k - 1]) % MOD

    max_factorial = 2 * 10**5 + 5000
    fact = [1] * (max_factorial + 1)
    for i in range(1, max_factorial + 1):
        fact[i] = fact[i - 1] * i % MOD

    inv_fact = [1] * (max_factorial + 1)
    inv_fact[max_factorial] = pow(fact[max_factorial], MOD - 2, MOD)
    for i in range(max_factorial - 1, -1, -1):
        inv_fact[i] = inv_fact[i + 1] * (i + 1) % MOD

    X = (M + N) % MOD
    max_k = max(K_list) if K_list else 0
    max_r = N + max_k

    if max_r > max_factorial:
        print(0)
        return

    # Compute falling factorials
    falling = [0] * (max_r + 1)
    falling[0] = 1
    for r in range(1, max_r + 1):
        falling[r] = falling[r - 1] * (X - (r - 1)) % MOD

    total = 0
    for K in K_list:
        current = 0
        if K == 0:
            continue
        for k in range(1, K + 1):
            r = N + k
            if r > max_r:
                continue
            s = stirling[K][k]
            if s == 0:
                continue
            numerator = falling[r]
            if numerator == 0:
                continue
            denom = inv_fact[r]
            term = numerator * denom % MOD
            term = term * s % MOD
            term = term * fact[k] % MOD
            current = (current + term) % MOD
        total = (total + current) % MOD

    print(total % MOD)

if __name__ == '__main__':
    main()
0