結果

問題 No.2004 Incremental Coins
ユーザー gew1fw
提出日時 2025-06-12 13:03:36
言語 PyPy3
(7.3.15)
結果
RE  
実行時間 -
コード長 1,860 bytes
コンパイル時間 267 ms
コンパイル使用メモリ 82,644 KB
実行使用メモリ 152,380 KB
最終ジャッジ日時 2025-06-12 13:09:02
合計ジャッジ時間 9,126 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 14 RE * 6
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    idx = 0
    N = int(data[idx])
    idx += 1
    K = int(data[idx])
    idx += 1
    A = list(map(int, data[idx:idx+N+1]))
    idx += N+1
    P = list(map(int, data[idx:idx+N]))
    
    children = [[] for _ in range(N+1)]
    for j in range(1, N+1):
        p = P[j-1]
        children[p].append(j)
    
    depth = [0] * (N+1)
    stack = [(0, 0)]
    while stack:
        u, d = stack.pop()
        depth[u] = d
        for v in children[u]:
            stack.append((v, d+1))
    
    max_depth = max(depth)
    
    max_k = max_depth
    fact = [1] * (max_k + 2)
    for i in range(1, max_k + 1):
        fact[i] = fact[i-1] * i % MOD
    inv_fact = [1] * (max_k + 2)
    inv_fact[max_k] = pow(fact[max_k], MOD-2, MOD)
    for i in range(max_k-1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
    
    C = [0] * (max_k + 1)
    if K == 0:
        C = [0] * (max_k + 1)
        C[0] = 1
    else:
        C[0] = 1
        for k in range(1, max_k + 1):
            term = C[k-1] * (K - k + 1) % MOD
            term = term * inv_fact[k] % MOD
            term = term * fact[k-1] % MOD
            C[k] = term
    
    result = [0] * (N+1)
    
    from collections import defaultdict
    
    def dfs(u):
        cnt = defaultdict(int)
        cnt[0] = A[u]
        for v in children[u]:
            child_cnt = dfs(v)
            for k in child_cnt:
                new_k = k + 1
                cnt[new_k] = (cnt[new_k] + child_cnt[k]) % MOD
        res = 0
        for k in cnt:
            if k > max_k:
                continue
            res = (res + cnt[k] * C[k]) % MOD
        result[u] = res % MOD
        return cnt
    
    dfs(0)
    
    for i in range(N+1):
        print(result[i] % MOD)

if __name__ == '__main__':
    main()
0