結果

問題 No.912 赤黒木
ユーザー lam6er
提出日時 2025-04-15 23:08:16
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,411 bytes
コンパイル時間 584 ms
コンパイル使用メモリ 81,700 KB
実行使用メモリ 78,568 KB
最終ジャッジ日時 2025-04-15 23:10:31
合計ジャッジ時間 10,929 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other WA * 10 TLE * 1 -- * 19
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict, deque

def _read_adjacency(n, readline):
    """Read ``n - 1`` lines of ``u v`` pairs into a 1-indexed adjacency list.

    Args:
        n: Number of vertices; vertices are numbered ``1..n``.
        readline: Zero-argument callable returning the next input line.

    Returns:
        A list of ``n + 1`` neighbour lists (index 0 is unused).
    """
    adjacency = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        u, v = map(int, readline().split())
        adjacency[u].append(v)
        adjacency[v].append(u)
    return adjacency


def _build_red_distance(red_adj, n):
    """Return ``distance(u, v)``, the edge count between u and v in the red tree.

    The red tree is rooted at vertex 1 with a BFS and queried through a
    binary-lifting ancestor table.  Assumes ``red_adj`` describes a connected
    tree on vertices ``1..n`` (vertices unreachable from 1 keep depth 0).
    """
    # Every depth is below n, so n.bit_length() levels cover any jump length.
    levels = max(1, n.bit_length())
    up = [[-1] * (n + 1) for _ in range(levels)]
    depth = [0] * (n + 1)

    seen = [False] * (n + 1)
    seen[1] = True
    queue = deque([1])
    direct = up[0]
    while queue:
        u = queue.popleft()
        for v in red_adj[u]:
            if not seen[v]:
                seen[v] = True
                direct[v] = u
                depth[v] = depth[u] + 1
                queue.append(v)

    # up[k][v] is the 2**k-th ancestor of v, or -1 when it does not exist.
    for k in range(1, levels):
        lower, upper = up[k - 1], up[k]
        for v in range(1, n + 1):
            mid = lower[v]
            if mid != -1:
                upper[v] = lower[mid]

    def distance(u, v):
        a, b = (u, v) if depth[u] >= depth[v] else (v, u)
        # Lift the deeper endpoint to the depth of the shallower one.
        gap = depth[a] - depth[b]
        k = 0
        while gap:
            if gap & 1:
                a = up[k][a]
            gap >>= 1
            k += 1
        if a != b:
            # Climb both endpoints to just below their lowest common ancestor.
            for k in range(levels - 1, -1, -1):
                if up[k][a] != up[k][b]:
                    a, b = up[k][a], up[k][b]
            a = up[0][a]
        return depth[u] + depth[v] - 2 * depth[a]

    return distance


def _min_rooted_cost(black_adj, distance, n):
    """Return the smallest rooted cost of the black tree over all roots.

    For a fixed root, every vertex contributes the sum of the red distances
    to its black children minus the largest of those distances (0 when it has
    no children).  The cost of a root is the sum of these contributions.

    Moving the root across one black edge only changes the contributions of
    that edge's two endpoints, so all roots are evaluated with one rerooting
    pass instead of one full traversal per root.  Assumes ``black_adj``
    describes a connected tree on vertices ``1..n``.
    """
    parent = [0] * (n + 1)
    weight = [0] * (n + 1)  # red distance between v and its black parent
    total = [0] * (n + 1)   # sum of incident edge weights
    best = [0] * (n + 1)    # largest incident edge weight
    second = [0] * (n + 1)  # second largest incident edge weight

    seen = [False] * (n + 1)
    seen[1] = True
    order = []
    queue = deque([1])
    while queue:
        u = queue.popleft()
        order.append(u)
        for v in black_adj[u]:
            if seen[v]:
                continue
            seen[v] = True
            parent[v] = u
            w = distance(u, v)
            weight[v] = w
            for end in (u, v):
                total[end] += w
                if w > best[end]:
                    second[end] = best[end]
                    best[end] = w
                elif w > second[end]:
                    second[end] = w
            queue.append(v)

    def as_root(u):
        # All neighbours of the root are its children.
        return total[u] - best[u]

    def below(u, w):
        # Contribution of u when its parent is the neighbour reached through
        # an edge of weight w: that edge is excluded from both sum and max.
        largest_child = second[u] if w == best[u] else best[u]
        return total[u] - w - largest_child

    cost = [0] * (n + 1)
    cost[1] = as_root(1) + sum(below(v, weight[v]) for v in order[1:])
    # BFS order guarantees cost[parent[v]] is final before v is handled.
    for v in order[1:]:
        p = parent[v]
        w = weight[v]
        cost[v] = cost[p] - as_root(p) - below(v, w) + below(p, w) + as_root(v)
    return min(cost[v] for v in order)


def main():
    """Read a red tree and a black tree from stdin and print the computed answer.

    Input: ``N``, then ``N - 1`` red edges, then ``N - 1`` black edges, one
    ``u v`` pair per line with 1-indexed vertices.

    Output: the minimum rooted cost of the black tree (see
    ``_min_rooted_cost``) plus ``N - 1``.

    NOTE(review): the judge verdict recorded with this submission is WA, so
    this cost model is not known to solve the task.  The model is kept as is;
    only its evaluation was changed, from one traversal per root to a single
    rerooting pass.
    """
    readline = sys.stdin.readline
    n = int(readline())
    red_adj = _read_adjacency(n, readline)
    black_adj = _read_adjacency(n, readline)
    distance = _build_red_distance(red_adj, n)
    print(_min_rooted_cost(black_adj, distance, n) + (n - 1))

# Script entry point: run the solver only when executed directly, not on import.
if __name__ == "__main__":
    main()
0