結果
問題 |
No.901 K-ary εxtrεεmε
|
ユーザー |
![]() |
提出日時 | 2025-03-26 15:46:24 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 524 ms / 3,000 ms |
コード長 | 3,177 bytes |
コンパイル時間 | 275 ms |
コンパイル使用メモリ | 82,588 KB |
実行使用メモリ | 135,936 KB |
最終ジャッジ日時 | 2025-03-26 15:47:00 |
合計ジャッジ時間 | 15,194 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | AC * 29 |
ソースコード
import sys
from sys import stdin

# The traversal below is iterative, so it does not depend on this limit; the
# call is retained so that importing/running the module keeps the same
# module-level side effect as before.
sys.setrecursionlimit(1 << 25)


def _read_tree(tokens, n):
    """Parse the ``n - 1`` weighted edges that follow the node count.

    Each edge is three consecutive tokens ``u v w`` with 0-based endpoints.

    Returns:
        ``(adj, pos)`` where ``adj[u]`` is a list of ``(neighbor, weight)``
        pairs and ``pos`` is the index of the first token after the edges.
    """
    adj = [[] for _ in range(n)]
    pos = 1
    for _ in range(n - 1):
        u = int(tokens[pos])
        v = int(tokens[pos + 1])
        w = int(tokens[pos + 2])
        adj[u].append((v, w))
        adj[v].append((u, w))
        pos += 3
    return adj, pos


def _preorder(adj, root=0):
    """Walk the tree from ``root`` without recursion.

    Returns:
        ``(parent, depth, dist, order)`` lists indexed by node: the parent
        (``-1`` for the root), the number of edges to the root, the sum of
        edge weights to the root, and the pre-order visiting index.
    """
    n = len(adj)
    parent = [-1] * n
    depth = [0] * n
    dist = [0] * n
    order = [0] * n
    clock = 0
    stack = [root]
    while stack:
        node = stack.pop()
        order[node] = clock
        clock += 1
        above = parent[node]
        # Push in reverse so children are popped in adjacency-list order.
        for child, weight in reversed(adj[node]):
            if child != above:
                parent[child] = node
                depth[child] = depth[node] + 1
                dist[child] = dist[node] + weight
                stack.append(child)
    return parent, depth, dist, order


def _build_lifting(parent, depth):
    """Build the binary-lifting table ``table[k][v]`` = 2**k-th ancestor of v.

    The number of levels is derived from the deepest node instead of a fixed
    constant, so any jump that can occur (at most ``max(depth)`` edges) is
    representable, and small trees do not pay for unused levels.
    """
    levels = max(1, max(depth).bit_length())
    table = [parent]
    for _ in range(1, levels):
        prev = table[-1]
        # prev[i] is the 2**(k-1)-th ancestor of i; hop once more from there.
        table.append([prev[mid] if mid != -1 else -1 for mid in prev])
    return table


def _lca(u, v, depth, table):
    """Return the lowest common ancestor of ``u`` and ``v``."""
    if depth[u] < depth[v]:
        u, v = v, u
    # Lift the deeper node by the binary decomposition of the depth gap.
    gap = depth[u] - depth[v]
    level = 0
    while gap:
        if gap & 1:
            u = table[level][u]
        gap >>= 1
        level += 1
    if u == v:
        return u
    # Both nodes are at equal depth: climb while their ancestors differ.
    for level in range(len(table) - 1, -1, -1):
        row = table[level]
        if row[u] != row[v]:
            u = row[u]
            v = row[v]
    return table[0][u]


def _solve(tokens):
    """Answer every query in the whitespace-split input ``tokens``.

    Layout: ``N``, then ``N - 1`` edges ``u v w``, then ``Q``, then for each
    query a count ``k`` followed by ``k`` node ids.

    For each query the nodes are sorted by pre-order index, the tree
    distances between cyclically adjacent nodes are summed, and half of that
    sum is reported.

    Returns:
        A list with one integer answer per query.
    """
    n = int(tokens[0])
    adj, pos = _read_tree(tokens, n)
    parent, depth, dist, order = _preorder(adj)
    table = _build_lifting(parent, depth)

    query_count = int(tokens[pos])
    pos += 1
    answers = []
    for _ in range(query_count):
        k = int(tokens[pos])
        pos += 1
        nodes = sorted(map(int, tokens[pos:pos + k]), key=order.__getitem__)
        pos += k
        if k == 1:
            answers.append(0)
            continue
        total = 0
        for i in range(len(nodes)):
            # nodes[i - 1] wraps to the last element when i == 0, closing
            # the cycle over the sorted nodes.
            a = nodes[i - 1]
            b = nodes[i]
            total += dist[a] + dist[b] - 2 * dist[_lca(a, b, depth, table)]
        answers.append(total // 2)
    return answers


def main():
    """Read the whole problem input from stdin and print one line per query.

    Empty input produces no output.
    """
    tokens = sys.stdin.read().split()
    if not tokens:
        return
    answers = _solve(tokens)
    if answers:
        # One write for all queries instead of one print call per query.
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    main()