結果

問題 No.922 東北きりきざむたん
ユーザー qwewe
提出日時 2025-04-24 12:27:16
言語 PyPy3 (7.3.15)
結果
WA  
実行時間 -
コード長 4,798 bytes
コンパイル時間 240 ms
コンパイル使用メモリ 82,068 KB
実行使用メモリ 849,532 KB
最終ジャッジ日時 2025-04-24 12:28:39
合計ジャッジ時間 3,425 ms
ジャッジサーバーID (参考情報) judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other WA * 5 MLE * 1 -- * 20
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
from sys import setrecursionlimit
from collections import defaultdict, deque
input = sys.stdin.readline
setrecursionlimit(1 << 25)

def main():
    sys.setrecursionlimit(1 << 25)
    N, M, Q = map(int, input().split())
    edges = [[] for _ in range(N+1)]
    parent = [i for i in range(N+1)]
    rank = [1]*(N+1)
    
    def find(u):
        while parent[u] != u:
            parent[u] = parent[parent[u]]
            u = parent[u]
        return u
    
    def union(u, v):
        u = find(u)
        v = find(v)
        if u == v:
            return
        if rank[u] < rank[v]:
            parent[u] = v
            rank[v] += rank[u]
        else:
            parent[v] = u
            rank[u] += rank[v]
    
    adj_components = defaultdict(list)
    for _ in range(M):
        u, v = map(int, input().split())
        union(u, v)
        adj_components[find(u)].append((u, v))
    
    component_roots = set(find(u) for u in range(1, N+1))
    component_info = {}
    for root in component_roots:
        edges = adj_components.get(root, [])
        adj = defaultdict(list)
        for u, v in edges:
            adj[u].append(v)
            adj[v].append(u)
        component_info[root] = {
            'adj': adj,
            'depth': {},
            'up': None,
            'nodes': []
        }
        visited = set()
        q = deque([root])
        visited.add(root)
        parent_dict = {root: -1}
        depth_dict = {root: 0}
        nodes = []
        while q:
            u = q.popleft()
            nodes.append(u)
            for v in adj[u]:
                if v not in visited:
                    visited.add(v)
                    parent_dict[v] = u
                    depth_dict[v] = depth_dict[u] + 1
                    q.append(v)
        component_info[root]['nodes'] = nodes
        component_info[root]['depth'] = depth_dict
        LOG = 20
        up = [[-1]*(N+1) for _ in range(LOG)]
        for u in nodes:
            up[0][u] = parent_dict.get(u, -1)
        for k in range(1, LOG):
            for u in nodes:
                if up[k-1][u] != -1:
                    up[k][u] = up[k-1][up[k-1][u]]
                else:
                    up[k][u] = -1
        component_info[root]['up'] = up
    
    def lca(u, v, root):
        up = component_info[root]['up']
        depth = component_info[root]['depth']
        if depth[u] < depth[v]:
            u, v = v, u
        for k in range(19, -1, -1):
            if depth[u] - (1 << k) >= depth[v]:
                u = up[k][u]
        if u == v:
            return u
        for k in range(19, -1, -1):
            if up[k][u] != -1 and up[k][u] != up[k][v]:
                u = up[k][u]
                v = up[k][v]
        return up[0][u]
    
    total = 0
    S = set()
    X = defaultdict(list)
    for _ in range(Q):
        a, b = map(int, input().split())
        a_root = find(a)
        b_root = find(b)
        if a_root == b_root:
            depth = component_info[a_root]['depth']
            lca_node = lca(a, b, a_root)
            distance = depth[a] + depth[b] - 2 * depth[lca_node]
            total += distance
        else:
            S.add(a_root)
            S.add(b_root)
            X[a_root].append(a)
            X[b_root].append(b)
    
    for root in S:
        nodes = component_info[root]['nodes']
        adj = component_info[root]['adj']
        X_C = X[root]
        if not X_C:
            continue
        marked = defaultdict(bool)
        for node in X_C:
            marked[node] = True
        
        total_X = len(X_C)
        cnt = defaultdict(int)
        sum_sub = defaultdict(int)
        parent_dict = {}
        stack = [(root, None)]
        post_order = []
        while stack:
            u, p = stack.pop()
            post_order.append(u)
            parent_dict[u] = p
            for v in adj[u]:
                if v != p:
                    stack.append((v, u))
        post_order.reverse()
        for u in post_order:
            cnt[u] = 1 if marked[u] else 0
            sum_sub[u] = 0
            for v in adj[u]:
                if v == parent_dict.get(u, None):
                    continue
                cnt[u] += cnt[v]
                sum_sub[u] += sum_sub[v] + cnt[v]
        
        sum_dist = defaultdict(int)
        sum_dist[root] = sum_sub[root]
        q = deque()
        q.append((root, None))
        while q:
            u, p = q.popleft()
            for v in adj[u]:
                if v == p:
                    continue
                sum_dist[v] = sum_dist[u] - cnt[v] + (total_X - cnt[v])
                q.append((v, u))
        
        min_sum = min(sum_dist[node] for node in nodes)
        total += min_sum
    
    print(total)

# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
0