Result

| Field | Value |
|---|---|
| Problem | No.196 典型DP (1) |
| Contest | |
| User | lam6er |
| Submitted at | 2025-03-31 17:45:16 |
| Language | PyPy3 (7.3.15) |
| Result | AC |
| Execution time | 329 ms / 2,000 ms |
| Code length | 2,488 bytes |
| Compile time | 239 ms |
| Compile memory | 82,152 KB |
| Execution memory | 109,796 KB |
| Last judged at | 2025-03-31 17:46:13 |
| Total judge time | 7,412 ms |
| Judge server ID (for reference) | judge2 / judge4 |

| File pattern | Result |
|---|---|
| sample | AC * 3 |
| other | AC * 41 |
Source code
import sys
from collections import deque

sys.setrecursionlimit(1 << 25)


def main():
    MOD = 10**9 + 7
    n, K = map(int, sys.stdin.readline().split())
    edges = [[] for _ in range(n)]
    for _ in range(n-1):
        a, b = map(int, sys.stdin.readline().split())
        edges[a].append(b)
        edges[b].append(a)
    # Build children lists using BFS from root 0
    children = [[] for _ in range(n)]
    visited = [False] * n
    q = deque([0])
    visited[0] = True
    while q:
        u = q.popleft()
        for v in edges[u]:
            if not visited[v]:
                visited[v] = True
                children[u].append(v)
                q.append(v)
    # Compute subtree sizes using post-order traversal
    size = [1] * n
    post_order = []
    stack = [(0, False)]
    while stack:
        node, processed = stack.pop()
        if processed:
            post_order.append(node)
            for child in children[node]:
                size[node] += size[child]
        else:
            stack.append((node, True))
            # Reverse to process children in order
            for child in reversed(children[node]):
                stack.append((child, False))
    # Initialize DP arrays
    dp = [[0] * (K + 1) for _ in range(n)]
    for u in post_order:
        current = [0] * (K + 1)
        current[0] = 1  # Base case: no nodes selected
        for v in children[u]:
            # Create a new temporary array for the convolution
            new_current = [0] * (K + 1)
            # Get non-zero indices from current and dp[v]
            non_zero_current = [i for i in range(K+1) if current[i] != 0]
            non_zero_v = [j for j in range(K+1) if dp[v][j] != 0]
            for i in non_zero_current:
                for j in non_zero_v:
                    if i + j > K:
                        continue
                    new_current[i + j] = (new_current[i + j] + current[i] * dp[v][j]) % MOD
            current = new_current
        # Calculate the product of dp[v][0] for all children v
        product = 1
        for v in children[u]:
            product = (product * dp[v][0]) % MOD
            if product == 0:
                break
        s = size[u]
        if s <= K:
            current[s] = (current[s] + product) % MOD
        # Update dp[u]
        for k in range(K + 1):
            dp[u][k] = current[k]
    print(dp[0][K] % MOD)


if __name__ == '__main__':
    main()
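
Judging from the recurrence in the submission, `dp[u][k]` appears to count the ways to pick exactly `k` nodes in the subtree of `u` such that picking a node forces its entire subtree to be picked (the only way to pick `u` itself is to take all `size[u]` nodes). The following is a minimal sketch of the same DP written as a standalone function, assuming that reading is correct. The names `root_tree`, `count_closed_selections`, and `brute_force` are hypothetical and do not appear in the submission. The sketch also swaps the fixed-length `K + 1` arrays for arrays trimmed to `min(K, size) + 1`, a common way to bound the merge cost, and adds a brute-force checker for small trees.

```python
from itertools import combinations

MOD = 10**9 + 7


def root_tree(n, edge_list):
    """Root the tree at node 0; return (parent, preorder). Hypothetical helper."""
    adj = [[] for _ in range(n)]
    for a, b in edge_list:
        adj[a].append(b)
        adj[b].append(a)
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)
    return parent, order


def count_closed_selections(n, k, edge_list):
    """Count k-node subsets in which a picked node implies its whole subtree."""
    parent, order = root_tree(n, edge_list)
    children = [[] for _ in range(n)]
    for v in range(1, n):
        children[parent[v]].append(v)
    size = [1] * n
    dp = [None] * n
    # Reversed preorder visits every child before its parent.
    for u in reversed(order):
        cur = [1]  # before merging any child: only the empty selection
        for v in children[u]:
            child = dp[v]
            merged = [0] * (min(k, len(cur) + len(child) - 2) + 1)
            for i, ci in enumerate(cur):
                if ci == 0:
                    continue
                for j, cj in enumerate(child):
                    if i + j > k:
                        break
                    merged[i + j] = (merged[i + j] + ci * cj) % MOD
            cur = merged
            size[u] += size[v]
            dp[v] = None  # the child's table is no longer needed
        if size[u] <= k:
            # Picking u forces the whole subtree: exactly one way, at index size[u].
            cur.append(1)
        dp[u] = cur
    return dp[0][k] if k < len(dp[0]) else 0


def brute_force(n, k, edge_list):
    """Enumerate all k-subsets; valid iff a picked parent implies a picked child."""
    parent, _ = root_tree(n, edge_list)
    count = 0
    for subset in combinations(range(n), k):
        chosen = set(subset)
        if all(parent[v] not in chosen or v in chosen for v in range(1, n)):
            count += 1
    return count % MOD
```

For a star with root 0 and leaves 1, 2, 3, `count_closed_selections(4, 2, [(0, 1), (0, 2), (0, 3)])` returns 3, since any two of the three leaves can be picked while picking the root would require all four nodes. Comparing against `brute_force` on small random trees is one way to gain confidence in the recurrence before submitting.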