結果

問題 No.922 東北きりきざむたん
ユーザー lam6er
提出日時 2025-04-09 20:59:59
言語 PyPy3 (7.3.15)
結果 TLE
実行時間 -
コード長 4,240 bytes
コンパイル時間 325 ms
コンパイル使用メモリ 82,300 KB
実行使用メモリ 145,708 KB
最終ジャッジ日時 2025-04-09 21:01:01
合計ジャッジ時間 5,915 ms
ジャッジサーバーID (参考情報) judge3 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 6 TLE * 1 -- * 19
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
from collections import deque, defaultdict

sys.setrecursionlimit(1 << 25)


def _bfs_forest(n, adj):
    """Breadth-first traverse every tree of the forest on vertices 1..n.

    Returns ``(order, parent, depth, comp)``:

    * ``order`` -- all vertices; each tree is contiguous and every parent is
      listed before its children;
    * ``parent[v]`` -- BFS parent of ``v``, or ``0`` for a tree root (index 0
      is an unused sentinel, so ``parent[0] == 0``);
    * ``depth[v]`` -- number of edges from ``v`` to its tree's root;
    * ``comp[v]`` -- the root vertex of ``v``'s tree.

    A ``seen`` array (rather than a parent check) guarantees termination even
    if the input unexpectedly contains a cycle or a duplicated edge.
    """
    parent = [0] * (n + 1)
    depth = [0] * (n + 1)
    comp = [0] * (n + 1)
    seen = [False] * (n + 1)
    order = []
    for root in range(1, n + 1):
        if seen[root]:
            continue
        seen[root] = True
        comp[root] = root
        head = len(order)
        order.append(root)
        # `order` doubles as the BFS queue: `head` is the next vertex to expand.
        while head < len(order):
            u = order[head]
            head += 1
            child_depth = depth[u] + 1
            for v in adj[u]:
                if not seen[v]:
                    seen[v] = True
                    parent[v] = u
                    depth[v] = child_depth
                    comp[v] = root
                    order.append(v)
    return order, parent, depth, comp


def _build_ancestors(parent, n):
    """Return a binary-lifting table: ``up[k][v]`` is the 2**k-th ancestor of v.

    Ancestors above a root resolve to the sentinel 0 (``parent[0] == 0``).
    ``n.bit_length()`` levels suffice because every depth is at most n - 1.
    """
    up = [parent]
    for _ in range(1, max(1, n.bit_length())):
        prev = up[-1]
        # prev[prev[v]] for every v, written as a single pass over prev.
        up.append([prev[p] for p in prev])
    return up


def _tree_distance(u, v, up, depth):
    """Return the number of edges between ``u`` and ``v`` (same tree)."""
    du, dv = depth[u], depth[v]
    if du < dv:
        u, v = v, u
    # Lift the deeper endpoint up to the other endpoint's depth.
    diff = abs(du - dv)
    k = 0
    while diff:
        if diff & 1:
            u = up[k][u]
        diff >>= 1
        k += 1
    if u != v:
        # Lift both endpoints to just below their lowest common ancestor.
        for level in reversed(up):
            if level[u] != level[v]:
                u = level[u]
                v = level[v]
        u = up[0][u]
    # u is now the lowest common ancestor.
    return du + dv - 2 * depth[u]


def _airport_cost(n, order, parent, depth, comp, weight):
    """Return the summed walking cost to the best airport of every tree.

    ``weight[v]`` counts how many cross-tree trip endpoints lie at ``v``.
    Placing a tree's airport at ``x`` costs ``sum(weight[v] * dist(v, x))``
    over that tree; this returns, summed over all trees, the minimum such
    cost.  It runs in O(n) by rerooting: moving the airport from ``p`` to
    its child ``v`` brings the ``sub[v]`` weight inside v's subtree one edge
    closer and the remaining weight one edge farther.
    """
    # sub[v]: total weight in the subtree of v (children before parents).
    sub = weight[:]
    for v in reversed(order):
        p = parent[v]
        if p:
            sub[p] += sub[v]

    # Airport at a root costs the weighted sum of depths of its tree.
    cost = [0] * (n + 1)
    for v in order:
        cost[comp[v]] += weight[v] * depth[v]

    best = [0] * (n + 1)  # best[root]: minimum cost within root's tree
    for v in order:
        p = parent[v]
        root = comp[v]
        if p:
            cost[v] = cost[p] + sub[root] - 2 * sub[v]
            if cost[v] < best[root]:
                best[root] = cost[v]
        else:
            best[v] = cost[v]
    return sum(best)


def main():
    """Read the instance from stdin and print the minimum total walking cost.

    Input: ``N M Q``, then M roads ``u v``, then Q trips ``a b`` (vertices
    1..N).  Assumes the roads form a forest (presumably guaranteed by the
    problem statement -- TODO confirm).

    * A trip whose endpoints share a tree costs their tree distance.
    * A trip between trees costs the walk from ``a`` to its tree's airport
      plus the walk from ``b``'s tree's airport to ``b``.
    * Each tree has one airport, placed to minimise that tree's total
      walking.

    Runs in O((N + Q) log N), replacing the previous per-endpoint BFS and
    step-by-step ancestor walks, which were O(N * Q) and timed out.
    """
    data = sys.stdin.read().split()
    n, m, q = int(data[0]), int(data[1]), int(data[2])
    pos = 3

    adj = [[] for _ in range(n + 1)]
    for _ in range(m):
        u = int(data[pos])
        v = int(data[pos + 1])
        pos += 2
        adj[u].append(v)
        adj[v].append(u)

    order, parent, depth, comp = _bfs_forest(n, adj)
    up = _build_ancestors(parent, n)

    weight = [0] * (n + 1)
    total = 0
    for _ in range(q):
        a = int(data[pos])
        b = int(data[pos + 1])
        pos += 2
        if comp[a] == comp[b]:
            total += _tree_distance(a, b, up, depth)
        else:
            # Both endpoints must reach their own tree's airport.
            weight[a] += 1
            weight[b] += 1

    total += _airport_cost(n, order, parent, depth, comp, weight)
    print(total)

if __name__ == "__main__":
    # Run the solver only when executed as a script, not when imported.
    main()
0