Result
Problem | No.901 K-ary εxtrεεmε |
User |
Submitted at | 2025-03-26 15:46:24 |
Language | PyPy3 (7.3.15) |
Result | AC |
Execution time | 524 ms / 3,000 ms |
Code length | 3,177 bytes |
Compile time | 275 ms |
Compile memory usage | 82,588 KB |
Execution memory usage | 135,936 KB |
Last judged at | 2025-03-26 15:47:00 |
Total judge time | 15,194 ms |
Judge server ID (for reference) | judge4 / judge3 |
File pattern | Result |
---|---|
sample | AC * 1 |
other | AC * 29 |
Source code
import sys
from sys import stdin
sys.setrecursionlimit(1 << 25)

def main():
    input = sys.stdin.read().split()
    ptr = 0
    N = int(input[ptr])
    ptr += 1
    # Build adjacency list
    adj = [[] for _ in range(N)]
    for _ in range(N-1):
        u = int(input[ptr])
        v = int(input[ptr+1])
        w = int(input[ptr+2])
        adj[u].append((v, w))
        adj[v].append((u, w))
        ptr += 3
    # DFS to compute in_time, out_time, parent, depth_edges, depth_sum
    root = 0
    in_time = [0]*N
    out_time = [0]*N
    parent = [-1]*N
    depth_edges = [0]*N  # depth in terms of number of edges
    depth_sum = [0]*N  # depth in terms of sum of weights
    time = 0
    stack = [(root, -1, 0, 0, False)]  # node, parent, depth_edges, depth_sum, visited
    while stack:
        node, p, de, ds, visited = stack.pop()
        if not visited:
            parent[node] = p
            depth_edges[node] = de
            depth_sum[node] = ds
            in_time[node] = time
            time += 1
            # Push back with visited=True to process out_time later
            stack.append((node, p, de, ds, True))
            # Push children in reverse order to process them in order
            # Iterate and collect children (excluding parent)
            children = []
            for (v, w) in adj[node]:
                if v != p:
                    children.append((v, w))
            # Reverse to maintain order
            for child, w in reversed(children):
                stack.append((child, node, de+1, ds + w, False))
        else:
            out_time[node] = time
            time += 1
    # Precompute binary lifting for LCA (based on depth_edges)
    max_level = 20
    up = [[-1] * N for _ in range(max_level)]
    # up[0] is the immediate parent
    for i in range(N):
        up[0][i] = parent[i]
    for k in range(1, max_level):
        for i in range(N):
            if up[k-1][i] != -1:
                up[k][i] = up[k-1][up[k-1][i]]
            else:
                up[k][i] = -1

    def lca(u, v):
        if depth_edges[u] < depth_edges[v]:
            u, v = v, u
        # Bring u up to the depth of v
        for k in range(max_level-1, -1, -1):
            if depth_edges[u] - (1 << k) >= depth_edges[v]:
                u = up[k][u]
        if u == v:
            return u
        for k in range(max_level-1, -1, -1):
            if up[k][u] != -1 and up[k][u] != up[k][v]:
                u = up[k][u]
                v = up[k][v]
        return up[0][u]

    Q = int(input[ptr])
    ptr += 1
    for _ in range(Q):
        k = int(input[ptr])
        ptr += 1
        nodes = list(map(int, input[ptr:ptr+k]))
        ptr += k
        if k == 1:
            print(0)
            continue
        # Sort nodes by in_time
        nodes_sorted = sorted(nodes, key=lambda x: in_time[x])
        total = 0
        for i in range(k):
            u = nodes_sorted[i]
            v = nodes_sorted[(i+1) % k]
            ancestor = lca(u, v)
            distance = depth_sum[u] + depth_sum[v] - 2 * depth_sum[ancestor]
            total += distance
        print(total // 2)

if __name__ == "__main__":
    main()
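
The per-query loop above relies on a standard tree identity: if the queried vertices are sorted by DFS entry time and the distances between cyclically adjacent vertices are summed, every edge of the minimal connecting subtree is counted exactly twice, so half of that sum is the subtree's total weight. The following is a minimal sketch of that idea on a tiny hand-made tree. It is not the submitted code: `steiner_weight`, the sample tree, and the naive parent-climbing LCA (used here instead of binary lifting to keep the example short) are illustrative assumptions.

```python
def steiner_weight(n, edges, nodes, root=0):
    """Total edge weight of the smallest subtree connecting `nodes`.

    Hypothetical helper for illustration; uses a naive O(depth) LCA.
    """
    adj = [[] for _ in range(n)]
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))

    order = [0] * n    # DFS preorder index of each vertex
    parent = [-1] * n
    depth = [0] * n    # number of edges from the root
    dist = [0] * n     # sum of edge weights from the root
    stack = [root]
    t = 0
    while stack:
        x = stack.pop()
        order[x] = t
        t += 1
        for y, w in adj[x]:
            if y != parent[x]:
                parent[y] = x
                depth[y] = depth[x] + 1
                dist[y] = dist[x] + w
                stack.append(y)

    def lca(u, v):
        # Repeatedly lift the deeper vertex until the two meet.
        while u != v:
            if depth[u] < depth[v]:
                u, v = v, u
            u = parent[u]
        return u

    vs = sorted(nodes, key=lambda x: order[x])
    total = 0
    for i, u in enumerate(vs):
        v = vs[(i + 1) % len(vs)]  # cyclic neighbour in DFS order
        total += dist[u] + dist[v] - 2 * dist[lca(u, v)]
    return total // 2  # each subtree edge was counted twice


# Sample tree (assumed for illustration): 0-1 (w=1), 0-2 (w=2), 1-3 (w=3), 1-4 (w=4)
edges = [(0, 1, 1), (0, 2, 2), (1, 3, 3), (1, 4, 4)]
print(steiner_weight(5, edges, [3, 4, 2]))  # 10: all four edges are needed
print(steiner_weight(5, edges, [3, 4]))     # 7: only edges 1-3 and 1-4
```

The naive LCA is fine for a demonstration but takes time proportional to the depth per call; the submission's binary-lifting table answers each LCA in about `max_level` steps, which matters when there are many queries.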