結果
| 問題 | No.901 K-ary εxtrεεmε |
|---|---|
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-03-26 15:46:24 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 524 ms / 3,000 ms |
| コード長 | 3,177 bytes |
| コンパイル時間 | 275 ms |
| コンパイル使用メモリ | 82,588 KB |
| 実行使用メモリ | 135,936 KB |
| 最終ジャッジ日時 | 2025-03-26 15:47:00 |
| 合計ジャッジ時間 | 15,194 ms |
| ジャッジサーバーID (参考情報) | judge4 / judge3 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 29 |
ソースコード
import sys
from sys import stdin
# Raise the interpreter recursion limit (1 << 25). This is purely defensive:
# the tree traversal in main() uses an explicit stack, not recursion.
sys.setrecursionlimit(1 << 25)
def _read_tree(tokens, pos, n):
    """Parse ``n - 1`` weighted edges ``u v w`` from ``tokens`` starting at ``pos``.

    Returns ``(adj, pos)``: ``adj[u]`` lists ``(v, w)`` pairs for every edge
    incident to ``u`` (edges are stored in both directions), and ``pos`` is
    the index just past the last edge token.
    """
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v, w = int(tokens[pos]), int(tokens[pos + 1]), int(tokens[pos + 2])
        pos += 3
        adj[u].append((v, w))
        adj[v].append((u, w))
    return adj, pos


def _dfs_order(adj, root=0):
    """Walk the tree from ``root`` with an explicit stack (no recursion).

    Returns ``(in_time, parent, depth, wdepth)``:
      * ``in_time[v]`` -- DFS preorder index of ``v``;
      * ``parent[v]``  -- parent of ``v``; ``parent[root] == root`` so that the
        binary-lifting table can clamp at the root without ``-1`` checks;
      * ``depth[v]``   -- number of edges from ``root``;
      * ``wdepth[v]``  -- sum of edge weights from ``root``.
    """
    n = len(adj)
    in_time = [-1] * n
    parent = [root] * n
    depth = [0] * n
    wdepth = [0] * n
    stack = [root]
    timer = 0
    while stack:
        node = stack.pop()
        # LIFO order finishes a node's whole subtree before its siblings,
        # so the order in which nodes are popped is a valid DFS preorder.
        in_time[node] = timer
        timer += 1
        for child, w in adj[node]:
            if child != parent[node]:
                parent[child] = node
                depth[child] = depth[node] + 1
                wdepth[child] = wdepth[node] + w
                stack.append(child)
    return in_time, parent, depth, wdepth


def _build_lifting(parent, levels):
    """Return ``up`` where ``up[k][v]`` is the ``2**k``-th ancestor of ``v``.

    Ancestors past the root clamp to the root, because ``parent[root]`` is
    the root itself.
    """
    up = [parent]
    for _ in range(1, levels):
        prev = up[-1]
        # 2**k-th ancestor = 2**(k-1)-th ancestor of the 2**(k-1)-th ancestor.
        up.append([prev[p] for p in prev])
    return up


def main():
    """Solve yukicoder No.901 reading from stdin and writing to stdout.

    Input: ``N``, then ``N - 1`` edges ``u v w`` (0-indexed vertices), then
    ``Q`` queries, each ``k x_1 ... x_k``. For every query, print the total
    weight of the smallest connected subtree containing all ``x_i``.

    If the query vertices are sorted by DFS preorder and visited cyclically,
    the closed walk crosses every edge of that subtree exactly twice. The
    answer is therefore half the sum of consecutive pairwise distances.
    """
    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    adj, pos = _read_tree(tokens, 1, n)
    in_time, parent, depth, wdepth = _dfs_order(adj)

    # Every depth is <= n - 1 < 2**levels, so `levels` lifting rows suffice
    # for any N (the original fixed value of 20 broke for depths >= 2**20).
    levels = max(1, n.bit_length())
    up = _build_lifting(parent, levels)

    def lca(u, v):
        """Lowest common ancestor of ``u`` and ``v`` via binary lifting."""
        if depth[u] < depth[v]:
            u, v = v, u
        # Lift u by exactly the depth difference, one set bit at a time.
        diff = depth[u] - depth[v]
        k = 0
        while diff:
            if diff & 1:
                u = up[k][u]
            diff >>= 1
            k += 1
        if u == v:
            return u
        # Jump both nodes while their ancestors still differ.
        for k in range(levels - 1, -1, -1):
            row = up[k]
            if row[u] != row[v]:
                u, v = row[u], row[v]
        return parent[u]

    def dist(u, v):
        """Weighted path length between ``u`` and ``v``."""
        return wdepth[u] + wdepth[v] - 2 * wdepth[lca(u, v)]

    q = int(tokens[pos])
    pos += 1
    out = []
    for _ in range(q):
        k = int(tokens[pos])
        nodes = sorted(map(int, tokens[pos + 1:pos + 1 + k]), key=in_time.__getitem__)
        pos += 1 + k
        if k <= 1:
            out.append(0)
            continue
        cyclic_next = nodes[1:] + nodes[:1]
        total = sum(dist(a, b) for a, b in zip(nodes, cyclic_next))
        out.append(total // 2)

    # One buffered write instead of a print() call per query.
    if out:
        print("\n".join(map(str, out)))
if __name__ == "__main__":
    # Run the solver only when executed as a script, never on import.
    main()
lam6er