結果
問題 | No.1718 Random Squirrel |
ユーザー |
|
提出日時 | 2021-08-05 17:48:49 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 521 ms / 2,000 ms |
コード長 | 2,566 bytes |
コンパイル時間 | 219 ms |
コンパイル使用メモリ | 82,192 KB |
実行使用メモリ | 120,740 KB |
最終ジャッジ日時 | 2024-09-23 03:47:35 |
合計ジャッジ時間 | 8,044 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 31 |
ソースコード
INF = 1001001001N, K = map(int, input().split())adj = [[] for _ in range(N)]for _ in range(N - 1):u, v = map(int, input().split())u -= 1; v -= 1adj[u].append(v)adj[v].append(u)Ds = list(map(int, input().split()))for i in range(K):Ds[i] -= 1in_D = [False] * Nfor D in Ds:in_D[D] = True# 最小シュタイナー木みたいなのを作る。reserved = [False] * Nroot = Ds[0]parent = [-1] * Nstack = [root]while stack:v = stack.pop()if in_D[v]:reserved[v] = Truewhile parent[v] != -1 and not reserved[parent[v]]:v = parent[v]reserved[v] = Truefor nv in adj[v]:if nv == parent[v]:continueparent[nv] = vstack.append(nv)# print(reserved)# 各頂点から、# 最小シュタイナー木みたいなのにたどりつくための距離と、# たどりつく先の頂点を求める。dist_Steiner = [INF] * Nnearest = [-1] * Nfor v in range(N):if reserved[v]:dist_Steiner[v] = 0nearest[v] = vstack = []for v in range(N):if reserved[v]:for nv in adj[v]:if not reserved[nv]:stack.append(v)breakwhile stack:v = stack.pop()for nv in adj[v]:if dist_Steiner[nv] > dist_Steiner[v] + 1:dist_Steiner[nv] = dist_Steiner[v] + 1nearest[nv] = nearest[v]stack.append(nv)# シュタイナー木を一周するときに、帰りの分をさぼれる。# そのさぼった分の最大化をする。def dist_from_root_in_reserved(root):dist_diameter = [INF] * Ndist_diameter[root] = 0stack = [root]while stack:v = stack.pop()for nv in adj[v]:if reserved[nv] and dist_diameter[nv] > dist_diameter[v] + 1:dist_diameter[nv] = dist_diameter[v] + 1stack.append(nv)for i in range(N):if dist_diameter[i] == INF:dist_diameter[i] = -1return dist_diameterdist_0 = dist_from_root_in_reserved(Ds[0])end1 = max(range(N), key = lambda i: dist_0[i])dist_end1 = dist_from_root_in_reserved(end1)end2 = max(range(N), key = lambda i: dist_end1[i])dist_end2 = dist_from_root_in_reserved(end2)max_dist = [max(e1, e2) for e1, e2 in zip(dist_end1, dist_end2)]# print(max_dist)# print(dist_0)# print(dist_end1)# print(dist_end2)loop_len = (sum(reserved) - 1) * 2for v in range(N):s = nearest[v]answer = dist_Steiner[v] + loop_len - max_dist[s]print(answer)