結果

問題 No.922 東北きりきざむたん
ユーザー qwewe
提出日時 2025-05-14 12:51:04
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 4,104 bytes
コンパイル時間 507 ms
コンパイル使用メモリ 82,364 KB
実行使用メモリ 199,840 KB
最終ジャッジ日時 2025-05-14 12:52:03
合計ジャッジ時間 14,373 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 8 WA * 18
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
from collections import deque, defaultdict

def main():
    sys.setrecursionlimit(1 << 25)
    input = sys.stdin.read().split()
    ptr = 0
    N, M, Q = map(int, input[ptr:ptr+3])
    ptr +=3

    adj = [[] for _ in range(N+1)]
    for _ in range(M):
        u = int(input[ptr])
        v = int(input[ptr+1])
        ptr +=2
        adj[u].append(v)
        adj[v].append(u)

    LOG = 20
    parent = [-1]*(N+1)
    depth = [0]*(N+1)
    component_id = [0]*(N+1)
    visited = [False]*(N+1)
    components = []
    cid = 0
    children = [[] for _ in range(N+1)]

    for u in range(1, N+1):
        if not visited[u]:
            cid +=1
            q = deque()
            q.append(u)
            visited[u] = True
            component = []
            parent[u] = -1
            depth[u] = 0
            while q:
                v = q.popleft()
                component.append(v)
                component_id[v] = cid
                for nei in adj[v]:
                    if not visited[nei]:
                        visited[nei] = True
                        parent[nei] = v
                        depth[nei] = depth[v] +1
                        q.append(nei)
                        children[v].append(nei)
            components.append(component)

    up = [[-1]*(LOG) for _ in range(N+1)]
    for u in range(1, N+1):
        up[u][0] = parent[u]
    for k in range(1, LOG):
        for u in range(1, N+1):
            if up[u][k-1] != -1:
                up[u][k] = up[up[u][k-1]][k-1]
            else:
                up[u][k] = -1

    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        for k in reversed(range(LOG)):
            if up[u][k] != -1 and depth[u] - (1 << k) >= depth[v]:
                u = up[u][k]
        if u == v:
            return u
        for k in reversed(range(LOG)):
            if up[u][k] != -1 and up[u][k] != up[v][k]:
                u = up[u][k]
                v = up[v][k]
        return parent[u]

    def distance(u, v):
        ancestor = lca(u, v)
        return depth[u] + depth[v] - 2 * depth[ancestor]

    sum_total = 0
    S = defaultdict(set)
    T = defaultdict(set)
    needs_airport = defaultdict(bool)

    for _ in range(Q):
        a = int(input[ptr])
        b = int(input[ptr+1])
        ptr +=2
        ca = component_id[a]
        cb = component_id[b]
        if ca == cb:
            sum_total += distance(a, b)
        else:
            S[ca].add(a)
            T[cb].add(b)
            needs_airport[ca] = True
            needs_airport[cb] = True

    component_nodes = defaultdict(list)
    for u in range(1, N+1):
        component_nodes[component_id[u]].append(u)

    for c in needs_airport:
        if not needs_airport[c]:
            continue
        targets = S[c].union(T[c])
        if not targets:
            sum_total +=0
            continue
        nodes = component_nodes[c]
        root = -1
        for u in nodes:
            if parent[u] == -1:
                root = u
                break
        count = defaultdict(int)
        sum_dist = defaultdict(int)
        post_order = []
        stack = [(root, False)]
        while stack:
            u, processed = stack.pop()
            if processed:
                post_order.append(u)
                cnt = 1 if u in targets else 0
                s = 0
                for v in children[u]:
                    cnt += count[v]
                    s += sum_dist[v] + count[v]
                count[u] = cnt
                sum_dist[u] = s
            else:
                stack.append( (u, True) )
                for v in reversed(children[u]):
                    stack.append( (v, False) )
        total_targets = len(targets)
        stack = [root]
        while stack:
            u = stack.pop()
            for v in children[u]:
                sum_dist[v] = sum_dist[u] - count[v] + (total_targets - count[v])
                stack.append(v)
        min_sum = min( sum_dist[u] for u in nodes )
        sum_total += min_sum

    print(sum_total)

if __name__ == '__main__':
    main()
0