結果
問題 | No.901 K-ary εxtrεεmε |
ユーザー |
![]() |
提出日時 | 2025-03-31 17:20:40 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 512 ms / 3,000 ms |
コード長 | 2,499 bytes |
コンパイル時間 | 157 ms |
コンパイル使用メモリ | 82,656 KB |
実行使用メモリ | 136,020 KB |
最終ジャッジ日時 | 2025-03-31 17:21:23 |
合計ジャッジ時間 | 14,609 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | AC * 29 |
ソースコード
import sys

# Carried over from the original submission. Every traversal below is
# iterative, so nothing here actually depends on the raised limit.
sys.setrecursionlimit(1 << 25)


def main():
    """Answer yukicoder No.901 queries read from stdin, one line per query.

    Input (whitespace separated): ``N``; then ``N - 1`` weighted edges
    ``u v w`` with 0-indexed endpoints; then ``Q``; then ``Q`` queries,
    each ``k x_1 ... x_k``.

    For each query, writes the total edge weight of the smallest connected
    subgraph of the tree containing every ``x_i`` (0 when ``k <= 1``).

    Method: visiting the query vertices in DFS pre-order and then returning
    to the first one walks each edge of that minimal subtree exactly twice,
    so the answer is half the summed distances between cyclically adjacent
    vertices. Distances use root distances plus a binary-lifting LCA.
    """
    data = sys.stdin.read().split()
    pos = 0

    n = int(data[pos])
    pos += 1
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        u = int(data[pos])
        v = int(data[pos + 1])
        w = int(data[pos + 2])
        pos += 3
        adj[u].append((v, w))
        adj[v].append((u, w))

    # --- Iterative DFS from vertex 0 ---------------------------------------
    # in_time: pre-order index; depth: edge count from the root;
    # parent: -1 for the root; dist: summed edge weight from the root.
    in_time = [0] * n
    depth = [0] * n
    parent = [-1] * n
    dist = [0] * n
    visited = [False] * n
    stack = [(0, -1, 0)]  # (vertex, parent, weight of the edge to parent)
    timer = 0
    while stack:
        u, p, w = stack.pop()
        if visited[u]:
            continue
        visited[u] = True
        in_time[u] = timer
        timer += 1
        parent[u] = p
        if p != -1:
            dist[u] = dist[p] + w
            depth[u] = depth[p] + 1
        # Push in reverse so children are popped in adjacency-list order.
        # The parent is already visited, so it is filtered out here as well.
        for v, edge_w in reversed(adj[u]):
            if not visited[v]:
                stack.append((v, u, edge_w))

    # --- Binary lifting table ----------------------------------------------
    # up[k][x] is the 2**k-th ancestor of x, or -1 if it does not exist.
    # Depth never exceeds n - 1 < 2**n.bit_length(), so this many levels
    # always suffice; deriving it from n (rather than a fixed constant)
    # keeps the LCA correct for arbitrarily deep trees.
    log_n = max(1, n.bit_length())
    up = [parent]
    for _ in range(1, log_n):
        prev_level = up[-1]
        up.append([prev_level[a] if a != -1 else -1 for a in prev_level])

    def lca(u, v):
        """Return the lowest common ancestor of vertices u and v."""
        if depth[u] < depth[v]:
            u, v = v, u
        # Lift u by the depth difference, one set bit at a time.
        diff = depth[u] - depth[v]
        k = 0
        while diff:
            if diff & 1:
                u = up[k][u]
            diff >>= 1
            k += 1
        if u == v:
            return u
        # Lift both while their ancestors differ; they stop just below the LCA.
        for k in range(log_n - 1, -1, -1):
            au = up[k][u]
            av = up[k][v]
            if au != av:
                u, v = au, av
        return parent[u]

    def distance(u, v):
        """Return the weighted tree distance between vertices u and v."""
        return dist[u] + dist[v] - 2 * dist[lca(u, v)]

    # --- Queries -------------------------------------------------------------
    q = int(data[pos])
    pos += 1
    answers = []
    for _ in range(q):
        k = int(data[pos])
        pos += 1
        xs = [int(t) for t in data[pos:pos + k]]
        pos += k
        if k <= 1:
            answers.append(0)
            continue
        xs.sort(key=in_time.__getitem__)
        # Sum distances over cyclically adjacent pairs (last -> first included).
        total = 0
        prev_x = xs[-1]
        for cur in xs:
            total += distance(prev_x, cur)
            prev_x = cur
        answers.append(total // 2)

    # One buffered write instead of one print() per query.
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    main()