結果
| 問題 | No.1718 Random Squirrel |
|---|---|
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-03-31 17:45:10 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 771 ms / 2,000 ms |
| コード長 | 4,339 bytes |
| コンパイル時間 | 365 ms |
| コンパイル使用メモリ | 82,720 KB |
| 実行使用メモリ | 141,332 KB |
| 最終ジャッジ日時 | 2025-03-31 17:46:14 |
| 合計ジャッジ時間 | 12,167 ms |
| ジャッジサーバーID (参考情報) | judge1 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 31 |
ソースコード
import sys
from collections import deque
def _bfs(n, adj, sources):
    """Breadth-first search over an adjacency list from one or more sources.

    Args:
        n: Number of vertices; vertices are numbered ``1..n``.
        adj: ``adj[u]`` lists the neighbours of ``u``.
        sources: Start vertices, all placed at distance 0.

    Returns:
        ``(dist, order, parent)`` where ``dist[x]`` is the number of edges
        from ``x`` to the nearest source (``-1`` if unreached), ``order``
        lists reached vertices in discovery order (non-decreasing distance,
        every vertex after the vertex it was discovered from) and
        ``parent[x]`` is that discovering vertex (``0`` for sources and
        unreached vertices).
    """
    dist = [-1] * (n + 1)
    parent = [0] * (n + 1)
    order = list(sources)
    for s in order:
        dist[s] = 0
    # ``order`` doubles as the FIFO queue: it is only ever appended to, so a
    # read cursor suffices and the discovery order comes for free.
    head = 0
    while head < len(order):
        u = order[head]
        head += 1
        nd = dist[u] + 1
        for w in adj[u]:
            if dist[w] == -1:
                dist[w] = nd
                parent[w] = u
                order.append(w)
    return dist, order, parent


def _steiner_vertices(n, adj, is_target):
    """Flag the vertices of the minimal subtree connecting all targets.

    A non-target vertex lies on that subtree exactly when targets are
    reachable through at least two of its incident edges. The tree is
    rooted at vertex 1; for each vertex we count child subtrees holding a
    target, plus the upward direction when some target lies outside the
    vertex's own subtree.

    Args:
        n: Number of vertices (``1..n``); ``adj`` is expected to be a tree.
        adj: Adjacency lists.
        is_target: ``is_target[x]`` is true for every target vertex.

    Returns:
        A list of ``n + 1`` booleans (index 0 unused); all ``False`` when
        there are no targets.
    """
    in_tree = [False] * (n + 1)
    # Count distinct targets (not the raw K) so a repeated id in the input
    # cannot make a vertex believe there are targets "above" it.
    total = sum(is_target)
    if total == 0:
        return in_tree
    _, order, parent = _bfs(n, adj, [1])
    below = [0] * (n + 1)     # targets inside each vertex's subtree
    branches = [0] * (n + 1)  # children whose subtree contains a target
    # Reverse BFS order visits every child before its parent.
    for u in reversed(order):
        if is_target[u]:
            below[u] += 1
        p = parent[u]
        if p and below[u]:
            below[p] += below[u]
            branches[p] += 1
    for u in order:
        upward = total - below[u] > 0
        in_tree[u] = is_target[u] or branches[u] + upward >= 2
    return in_tree


def main():
    """Read the tree and targets from stdin and print one answer per vertex.

    Input (whitespace separated): ``N K``, then ``N - 1`` edges ``u v``
    (1-indexed), then ``K`` target vertices ``D``.

    Let ``T`` be the minimal subtree spanning the targets and ``d`` the
    distance from start vertex ``x`` to ``T``. For each ``x = 1..N`` this
    prints ``2 * (|E(T)| + d) - max_{t in D} dist(x, t)``: every edge of
    ``T`` plus the access path is walked twice except along the path to the
    farthest target, where the walk may stop. (Presumably this is the
    "shortest walk visiting every target" asked by No.1718 -- confirm
    against the statement.) The farthest target from any vertex is one of
    the two endpoints of the target set's diameter, so two BFS runs give
    that maximum for every ``x`` at once. With no targets, prints ``0`` for
    every vertex.
    """
    data = sys.stdin.read().split()
    n, k = int(data[0]), int(data[1])
    adj = [[] for _ in range(n + 1)]
    pos = 2
    for _ in range(n - 1):
        u, v = int(data[pos]), int(data[pos + 1])
        adj[u].append(v)
        adj[v].append(u)
        pos += 2
    targets = [int(t) for t in data[pos:pos + k]]
    is_target = [False] * (n + 1)
    for t in targets:
        is_target[t] = True

    in_tree = _steiner_vertices(n, adj, is_target)
    tree_nodes = [x for x in range(1, n + 1) if in_tree[x]]
    if not tree_nodes:
        # No targets: nothing to visit from any start vertex.
        print('\n'.join('0' for _ in range(n)))
        return
    tree_edges = len(tree_nodes) - 1  # T is connected, so |E(T)| = |V(T)| - 1

    # Double sweep restricted to targets: a is the target farthest from an
    # arbitrary target, b the target farthest from a.
    dist0, _, _ = _bfs(n, adj, [targets[0]])
    a = max(targets, key=dist0.__getitem__)
    dist_a, _, _ = _bfs(n, adj, [a])
    b = max(targets, key=dist_a.__getitem__)
    dist_b, _, _ = _bfs(n, adj, [b])

    # Distance from every vertex to its nearest vertex of T.
    to_tree, _, _ = _bfs(n, adj, tree_nodes)

    out = [
        2 * (tree_edges + to_tree[x]) - max(dist_a[x], dist_b[x])
        for x in range(1, n + 1)
    ]
    print('\n'.join(map(str, out)))
if __name__ == '__main__':
    # Solve only when executed as a script, so importing the module is side-effect free.
    main()
lam6er