結果
問題 | No.912 赤黒木 |
ユーザー |
![]() |
提出日時 | 2025-04-15 23:13:13 |
言語 | PyPy3 (7.3.15) |
結果 | WA |
実行時間 | - |
コード長 | 3,411 bytes |
コンパイル時間 | 427 ms |
コンパイル使用メモリ | 82,232 KB |
実行使用メモリ | 194,412 KB |
最終ジャッジ日時 | 2025-04-15 23:16:18 |
合計ジャッジ時間 | 11,201 ms |
ジャッジサーバーID (参考情報) | judge1 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | WA * 10 TLE * 1 -- * 19 |
ソースコード
import sys
from collections import defaultdict, deque  # defaultdict is unused; kept from the original

# NOTE(review): this submission was judged WA (WA x10, TLE x1 on the page
# above). The "sum minus max" cost model below is the original author's
# heuristic, reproduced unchanged in its results. It is not a verified
# solution to the problem.


def _red_tree_distance(n, red_pairs):
    """Build the red tree and return a ``distance(u, v)`` function for it.

    The tree is rooted at vertex 1 via BFS. Distances are answered with a
    binary-lifting LCA table: depth[u] + depth[v] - 2 * depth[lca(u, v)].
    """
    adj = [[] for _ in range(n + 1)]
    for a, b in red_pairs:
        adj[a].append(b)
        adj[b].append(a)

    # Enough lifting levels to climb any depth < n. The original hard-coded 20,
    # which breaks once the red tree is deeper than 2**20 - 1.
    log = max(1, n.bit_length())
    parent = [[-1] * (n + 1) for _ in range(log)]
    depth = [0] * (n + 1)

    seen = [False] * (n + 1)
    seen[1] = True
    queue = deque([1])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                parent[0][v] = u
                depth[v] = depth[u] + 1
                queue.append(v)

    # parent[k][v] is the 2**k-th ancestor of v, or -1 above the root.
    for k in range(1, log):
        prev, cur = parent[k - 1], parent[k]
        for v in range(1, n + 1):
            p = prev[v]
            if p != -1:
                cur[v] = prev[p]

    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        # Lift u to v's depth.
        for k in range(log - 1, -1, -1):
            if depth[u] - (1 << k) >= depth[v]:
                u = parent[k][u]
        if u == v:
            return u
        # Lift both while their ancestors differ; they end as children of the LCA.
        for k in range(log - 1, -1, -1):
            if parent[k][u] != parent[k][v]:
                u = parent[k][u]
                v = parent[k][v]
        return parent[0][u]

    def distance(u, v):
        return depth[u] + depth[v] - 2 * depth[lca(u, v)]

    return distance


def _black_tree_cost(root, n, black_adj, distance):
    """Return the heuristic cost of the black tree rooted at *root*.

    For each black vertex u, take the red-tree distances from u to each of its
    black children. Add (their sum - their maximum) to the total; a vertex with
    no children contributes 0.
    """
    total = 0
    seen = [False] * (n + 1)
    stack = [(root, -1)]
    while stack:
        u, par = stack.pop()
        seen[u] = True
        acc = 0
        best = 0
        for v in black_adj[u]:
            if v == par or seen[v]:
                continue
            d = distance(u, v)  # computed once per child (the original computed it twice)
            acc += d
            if d > best:
                best = d
            # Visiting order cannot change the total, so no sorting is needed.
            stack.append((v, u))
        total += acc - best
    return total


def solve(n, red_pairs, black_pairs):
    """Return the value ``main`` prints for the given trees.

    Args:
        n: number of vertices (labelled 1..n).
        red_pairs: the n-1 red-tree edges as (a, b) pairs.
        black_pairs: the n-1 black-tree edges as (c, d) pairs.

    Returns:
        min over all roots of ``_black_tree_cost(root)``, plus (n - 1).

    Note: every root is tried, which is O(n^2 log n) overall. This is why the
    original submission also hit TLE.
    """
    if n <= 1:
        return 0  # no edges at all; matches the original's output for N = 1

    distance = _red_tree_distance(n, red_pairs)

    black_adj = [[] for _ in range(n + 1)]
    for c, d in black_pairs:
        black_adj[c].append(d)
        black_adj[d].append(c)

    best = min(
        _black_tree_cost(root, n, black_adj, distance) for root in range(1, n + 1)
    )
    return best + (n - 1)


def main():
    """Read N, then N-1 red edges, then N-1 black edges; print the answer."""
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    m = n - 1
    vals = list(map(int, data[1:1 + 4 * m]))
    red = [(vals[2 * i], vals[2 * i + 1]) for i in range(m)]
    off = 2 * m
    black = [(vals[off + 2 * i], vals[off + 2 * i + 1]) for i in range(m)]
    print(solve(n, red, black))


if __name__ == '__main__':
    main()