結果

問題 No.912 赤黒木
ユーザー lam6er
提出日時 2025-04-16 16:05:37
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,123 bytes
コンパイル時間 255 ms
コンパイル使用メモリ 82,456 KB
実行使用メモリ 82,328 KB
最終ジャッジ日時 2025-04-16 16:13:21
合計ジャッジ時間 7,114 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other WA * 10 TLE * 1 -- * 19
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict, deque

def main():
    sys.setrecursionlimit(1 << 25)
    N = int(sys.stdin.readline())
    red_edges = [[] for _ in range(N+1)]
    for _ in range(N-1):
        a, b = map(int, sys.stdin.readline().split())
        red_edges[a].append(b)
        red_edges[b].append(a)
    black_edges = [[] for _ in range(N+1)]
    for _ in range(N-1):
        c, d = map(int, sys.stdin.readline().split())
        black_edges[c].append(d)
        black_edges[d].append(c)

    # Precompute distances in red tree using BFS for each node
    def bfs(start):
        dist = [-1] * (N+1)
        q = deque([start])
        dist[start] = 0
        while q:
            u = q.popleft()
            for v in red_edges[u]:
                if dist[v] == -1:
                    dist[v] = dist[u] + 1
                    q.append(v)
        return dist

    # Precompute distances from all nodes
    red_dist = [None] * (N+1)
    for u in range(1, N+1):
        red_dist[u] = bfs(u)

    # Build black tree and find root (arbitrary, say 1)
    # Then convert to parent-children structure via BFS/DFS
    visited = [False] * (N+1)
    black_parent = [0] * (N+1)
    black_children = [[] for _ in range(N+1)]
    q = deque()
    root = 1
    q.append(root)
    visited[root] = True
    while q:
        u = q.popleft()
        for v in black_edges[u]:
            if not visited[v]:
                visited[v] = True
                black_parent[v] = u
                black_children[u].append(v)
                q.append(v)

    # Compute the answer
    total = 0

    def dfs(u):
        nonlocal total
        res = 0  # max distance in this subtree
        sum_d = 0
        max_d = 0
        for v in black_children[u]:
            d = dfs(v)
            current_dist = red_dist[v][u]
            sum_d += current_dist + d
            if current_dist + d > max_d:
                max_d = current_dist + d
        if len(black_children[u]) == 0:
            return 0
        total += sum_d - max_d
        return max_d

    dfs(root)
    print(total + (N-1))

if __name__ == "__main__":
    main()
0