結果

問題 No.196 典型DP (1)
ユーザー lam6er
提出日時 2025-03-31 17:45:16
言語 PyPy3 (7.3.15)
結果 AC
実行時間 329 ms / 2,000 ms
コード長 2,488 bytes
コンパイル時間 239 ms
コンパイル使用メモリ 82,152 KB
実行使用メモリ 109,796 KB
最終ジャッジ日時 2025-03-31 17:46:13
合計ジャッジ時間 7,412 ms
ジャッジサーバーID (参考情報) judge2 / judge4
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 41
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque
sys.setrecursionlimit(1 << 25)

def _rooted_children(n, edge_list, root=0):
    """Orient an undirected tree away from ``root``.

    Args:
        n: Number of nodes, labelled ``0 .. n-1``.
        edge_list: The undirected edges as ``(a, b)`` pairs.
        root: Node the tree is hung from.

    Returns:
        ``(children, order)``: ``children[u]`` lists the children of ``u`` and
        ``order`` is the BFS visit order from ``root``, so every parent appears
        before all of its children.
    """
    adjacency = [[] for _ in range(n)]
    for a, b in edge_list:
        adjacency[a].append(b)
        adjacency[b].append(a)

    children = [[] for _ in range(n)]
    visited = [False] * n
    visited[root] = True
    order = [root]
    queue = deque([root])
    while queue:
        u = queue.popleft()
        for v in adjacency[u]:
            if not visited[v]:
                visited[v] = True
                children[u].append(v)
                order.append(v)
                queue.append(v)
    return children, order


def _count_colorings(n, k, edge_list, mod):
    """Count size-``k`` black node sets in which every black node's subtree is all black.

    The tree is rooted at node 0.  ``dp[u][j]`` is the number of valid sets
    with exactly ``j`` black nodes inside the subtree of ``u``; tables are
    truncated at index ``k`` because larger counts can never contribute.

    Returns:
        The number of valid sets of size ``k``, reduced modulo ``mod``
        (0 when ``k`` is negative or exceeds ``n``).
    """
    if k < 0:
        return 0
    children, order = _rooted_children(n, edge_list)
    size = [1] * n
    dp = [None] * n
    for u in reversed(order):  # reversed BFS order: children before parents
        # Case "u is white": the children's subtrees are chosen independently,
        # so their tables are convolved.  Start from "no black nodes yet".
        table = [1]
        for v in children[u]:
            child = dp[v]
            merged = [0] * min(len(table) + len(child) - 1, k + 1)
            for i, ways_i in enumerate(table):
                if not ways_i:
                    continue
                for j, ways_j in enumerate(child):
                    if i + j > k:
                        break  # j only grows, so later terms also overflow k
                    if ways_j:
                        merged[i + j] = (merged[i + j] + ways_i * ways_j) % mod
            table = merged
            size[u] += size[v]
            dp[v] = None  # the child's table has been merged; free it
        # Case "u is black": its whole subtree must be black, which is exactly
        # one configuration using size[u] black nodes.  The white case uses at
        # most size[u] - 1 black nodes, so this slot is new.
        if size[u] <= k:
            table += [0] * (size[u] + 1 - len(table))
            table[size[u]] = (table[size[u]] + 1) % mod
        dp[u] = table
    root_table = dp[order[0]]
    return root_table[k] % mod if k < len(root_table) else 0


def main():
    """Read ``N K`` and ``N - 1`` edges from stdin; print the answer mod 1e9+7.

    Input is consumed as one whitespace-separated token stream: ``N K`` first,
    then the 0-indexed endpoints ``a b`` of each of the ``N - 1`` edges.
    """
    MOD = 10**9 + 7
    tokens = map(int, sys.stdin.read().split())
    n, k = next(tokens), next(tokens)
    edge_list = [(next(tokens), next(tokens)) for _ in range(n - 1)]
    print(_count_colorings(n, k, edge_list, MOD))


if __name__ == '__main__':
    main()
0