結果

問題 No.2005 Sum of Power Sums
ユーザー lam6er
提出日時 2025-04-16 16:29:20
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 1,730 bytes
コンパイル時間 233 ms
コンパイル使用メモリ 82,144 KB
実行使用メモリ 160,476 KB
最終ジャッジ日時 2025-04-16 16:30:43
合計ジャッジ時間 5,484 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 12 TLE * 1 -- * 5
権限があれば一括ダウンロードができます

ソースコード

diff #

mod = 998244353

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    
    N = int(data[0])
    M = int(data[1])
    K_list = list(map(int, data[2:2+N]))
    
    if not K_list:
        print(0)
        return
    
    max_K = max(K_list)
    max_r = N + max_K
    
    # Precompute Stirling numbers of the second kind up to max_K
    stirling = [[0] * (max_K + 1) for _ in range(max_K + 1)]
    stirling[0][0] = 1
    for k in range(1, max_K + 1):
        for m in range(1, k + 1):
            stirling[k][m] = (m * stirling[k-1][m] + stirling[k-1][m-1]) % mod
    
    # Precompute factorials and inverse factorials up to max_r
    fact = [1] * (max_r + 1)
    for i in range(1, max_r + 1):
        fact[i] = fact[i-1] * i % mod
    
    inv_fact = [1] * (max_r + 1)
    inv_fact[max_r] = pow(fact[max_r], mod-2, mod)
    for i in range(max_r - 1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % mod
    
    # Compute prefix array
    M_mod = M % mod
    prefix = [0] * (max_r + 1)
    prefix[0] = 1
    for r in range(1, max_r + 1):
        term = (M_mod + N - (r - 1)) % mod
        prefix[r] = prefix[r-1] * term % mod
    
    # Calculate the total sum
    total = 0
    for K in K_list:
        current = 0
        for m in range(0, K + 1):
            s = stirling[K][m]
            if s == 0:
                continue
            r = N + m
            if r > max_r:
                comb = 0
            else:
                comb = prefix[r] * inv_fact[r] % mod
            term = s * fact[m] % mod
            term = term * comb % mod
            current = (current + term) % mod
        total = (total + current) % mod
    
    print(total)

if __name__ == '__main__':
    main()
0