結果

問題 No.922 東北きりきざむたん
ユーザー gew1fw
提出日時 2025-06-12 15:25:13
言語 PyPy3 (7.3.15)
結果
WA  
実行時間 -
コード長 5,909 bytes
コンパイル時間 292 ms
コンパイル使用メモリ 82,356 KB
実行使用メモリ 125,680 KB
最終ジャッジ日時 2025-06-12 15:25:57
合計ジャッジ時間 4,749 ms
ジャッジサーバーID (参考情報) judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 1 WA * 4 TLE * 1 -- * 20
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque

def _build_forest(n, adj):
    """Root every connected component with a breadth-first search.

    Args:
        n: Number of vertices; vertices are numbered 1..n.
        adj: Adjacency lists indexed by vertex (index 0 unused).

    Returns:
        A tuple ``(order, parent, depth, root_of)`` where ``order`` lists all
        vertices in BFS order (component by component), ``parent[v]`` is the
        BFS parent of ``v`` (a root is its own parent), ``depth[v]`` is the
        number of edges from ``v`` to its root, and ``root_of[v]`` identifies
        the component of ``v`` by its root vertex.
    """
    parent = [0] * (n + 1)
    depth = [0] * (n + 1)
    root_of = [0] * (n + 1)  # 0 doubles as "not visited"; vertices are 1-based
    order = []
    for start in range(1, n + 1):
        if root_of[start]:
            continue
        root_of[start] = start
        parent[start] = start
        queue = deque([start])
        while queue:
            u = queue.popleft()
            order.append(u)
            for v in adj[u]:
                if not root_of[v]:
                    root_of[v] = start
                    parent[v] = u
                    depth[v] = depth[u] + 1
                    queue.append(v)
    return order, parent, depth, root_of


def _build_ancestor_table(parent, max_depth):
    """Return binary-lifting rows: ``table[k][v]`` is the 2**k-th ancestor of v.

    Roots are their own parent, so jumping past a root stays on the root.
    """
    # 2 ** levels must exceed the deepest vertex so any depth gap can be jumped.
    levels = max(1, max_depth.bit_length())
    table = [parent]
    for _ in range(levels - 1):
        prev = table[-1]
        # Iterating the row's values yields prev[prev[v]] in index order.
        table.append([prev[w] for w in prev])
    return table


def _tree_distance(a, b, depth, table):
    """Return the number of edges between a and b (same component required)."""
    u, v = a, b
    if depth[u] < depth[v]:
        u, v = v, u
    # Lift the deeper vertex until both are at the same depth.
    gap = depth[u] - depth[v]
    k = 0
    while gap:
        if gap & 1:
            u = table[k][u]
        gap >>= 1
        k += 1
    if u != v:
        # Descend through the jump sizes, staying strictly below the LCA.
        for row in reversed(table):
            if row[u] != row[v]:
                u = row[u]
                v = row[v]
        u = table[0][u]
    return depth[a] + depth[b] - 2 * depth[u]


def _min_airport_cost(n, order, parent, root_of, weight):
    """Sum, over all components, the cheapest total distance to one vertex.

    ``weight[v]`` is how many cross-component query endpoints sit on ``v``.
    For every component this picks the vertex minimising the weighted sum of
    tree distances to those endpoints and returns the sum of these minima.
    """
    size = weight[:]      # weighted endpoint count inside each subtree
    down = [0] * (n + 1)  # weighted distance sum from v into its own subtree
    # Reversed BFS order visits every child before its parent.
    for v in reversed(order):
        p = parent[v]
        if p != v:
            size[p] += size[v]
            down[p] += down[v] + size[v]

    cost = [0] * (n + 1)  # weighted distance sum from v to all endpoints
    best = {}
    for v in order:
        p = parent[v]
        if p == v:
            cost[v] = down[v]
            best[v] = cost[v]
        else:
            # Moving the airport from p to v brings size[v] endpoints one
            # edge closer and pushes the remaining ones one edge further.
            root = root_of[v]
            cost[v] = cost[p] + size[root] - 2 * size[v]
            if cost[v] < best[root]:
                best[root] = cost[v]
    return sum(best.values())


def main():
    """Read a forest and travel queries from stdin and print the minimum cost.

    Input (whitespace separated): ``N M Q``, then ``M`` undirected edges
    ``u v``, then ``Q`` queries ``a b``, with vertices numbered from 1.

    A query whose endpoints share a component costs their tree distance.
    Any other query costs the distance from each endpoint to the single
    "airport" vertex chosen in its own component; each component's airport is
    chosen to minimise the total over all such endpoints, counting an
    endpoint once per query it appears in. The grand total is printed.

    Distances are measured along the BFS tree of each component, so the
    result is only meaningful when the input graph is a forest.
    """
    tokens = sys.stdin.read().split()
    n, m, q = int(tokens[0]), int(tokens[1]), int(tokens[2])
    pos = 3

    adj = [[] for _ in range(n + 1)]
    for _ in range(m):
        u = int(tokens[pos])
        v = int(tokens[pos + 1])
        pos += 2
        adj[u].append(v)
        adj[v].append(u)

    order, parent, depth, root_of = _build_forest(n, adj)
    table = _build_ancestor_table(parent, max(depth))

    within_sum = 0
    weight = [0] * (n + 1)
    for _ in range(q):
        a = int(tokens[pos])
        b = int(tokens[pos + 1])
        pos += 2
        if root_of[a] == root_of[b]:
            within_sum += _tree_distance(a, b, depth, table)
        else:
            # Each occurrence counts: repeated endpoints pull the airport harder.
            weight[a] += 1
            weight[b] += 1

    cross_sum = _min_airport_cost(n, order, parent, root_of, weight)
    print(within_sum + cross_sum)

# Run the solver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
0