結果

問題 No.922 東北きりきざむたん
ユーザー gew1fw
提出日時 2025-06-12 15:26:22
言語 PyPy3 (7.3.15)
結果
MLE  
実行時間 -
コード長 5,421 bytes
コンパイル時間 171 ms
コンパイル使用メモリ 82,388 KB
実行使用メモリ 849,284 KB
最終ジャッジ日時 2025-06-12 15:26:58
合計ジャッジ時間 5,052 ms
ジャッジサーバーID (参考情報) judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 5 MLE * 1 -- * 20
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict, deque

# Nothing below recurses any more; the limit is left untouched so the
# module's import-time side effects stay the same.
sys.setrecursionlimit(1 << 25)


def _build_forest(n, edges):
    """Label the components of an undirected forest with one BFS pass.

    Returns ``(comp, parent, depth, order)``. For every vertex ``v`` in
    ``1..n``:

    - ``comp[v]`` is the smallest vertex of ``v``'s component, which also
      serves as that component's BFS root.
    - ``parent[v]`` is the BFS parent (``0`` for a root; index 0 acts as
      a sentinel whose parent is itself).
    - ``depth[v]`` is the edge distance from the root.

    ``order`` lists every vertex in BFS order. Each component occupies one
    contiguous run that starts with its root.
    """
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    comp = [0] * (n + 1)  # 0 means "not visited yet"; real labels are >= 1
    parent = [0] * (n + 1)
    depth = [0] * (n + 1)
    order = []
    for s in range(1, n + 1):
        if comp[s]:
            continue
        comp[s] = s
        head = len(order)
        order.append(s)
        # `order` doubles as the BFS queue; `head` is the read position.
        while head < len(order):
            u = order[head]
            head += 1
            child_depth = depth[u] + 1
            for v in adj[u]:
                if not comp[v]:
                    comp[v] = s
                    parent[v] = u
                    depth[v] = child_depth
                    order.append(v)
    return comp, parent, depth, order


def _ancestor_table(parent, levels):
    """Return ``up`` with ``up[k][v]`` = the 2**k-th ancestor of ``v`` (0 past a root).

    One table covers the whole forest because components are disjoint, so
    memory is O(N * levels) however many components there are.
    """
    up = [parent]
    for _ in range(1, levels):
        prev = up[-1]
        # up[k][v] = up[k-1][up[k-1][v]]
        up.append([prev[x] for x in prev])
    return up


def _lca(u, v, up, depth):
    """Return the lowest common ancestor of ``u`` and ``v`` (same component)."""
    if depth[u] < depth[v]:
        u, v = v, u
    # Lift the deeper vertex by the depth difference, one set bit at a time.
    diff = depth[u] - depth[v]
    k = 0
    while diff:
        if diff & 1:
            u = up[k][u]
        diff >>= 1
        k += 1
    if u == v:
        return u
    for k in range(len(up) - 1, -1, -1):
        row = up[k]
        if row[u] != row[v]:
            u = row[u]
            v = row[v]
    return up[0][u]


def _best_hub_cost(comp, parent, depth, order, weight):
    """Sum, over components, of ``min_c sum_x weight[x] * dist(x, c)``.

    Uses rerooting. ``cost(root) = sum weight * depth``. Moving the hub from
    a vertex to its child ``v`` brings ``sub[v]`` weight one step closer and
    pushes the rest one step farther:
    ``cost(v) = cost(parent) + total - 2 * sub[v]``.
    """
    size = len(weight)
    # Subtree weight sums: children appear after parents in BFS order.
    sub = weight[:]
    for v in reversed(order):
        p = parent[v]
        if p:
            sub[p] += sub[v]

    root_cost = [0] * size
    for v in order:
        if weight[v]:
            root_cost[comp[v]] += weight[v] * depth[v]

    cost = [0] * size
    best = [0] * size  # only entries at component roots are ever non-zero
    for v in order:
        r = comp[v]
        if v == r:
            c = root_cost[r]
            best[r] = c
        else:
            c = cost[parent[v]] + sub[r] - 2 * sub[v]
            if c < best[r]:
                best[r] = c
        cost[v] = c
    return sum(best)


def _solve(n, edges, queries):
    """Return the total cost for the given forest and (a, b) queries.

    A query whose endpoints share a component costs their tree distance.
    Otherwise each endpoint is charged its distance to one shared vertex
    chosen per component to minimise the total (presumably the problem's
    airport — confirm against the statement).
    """
    comp, parent, depth, order = _build_forest(n, edges)
    # 2**levels > n >= max depth + 1, so every depth difference fits.
    up = _ancestor_table(parent, max(1, n.bit_length()))

    fixed_sum = 0
    weight = [0] * (n + 1)
    for a, b in queries:
        if comp[a] == comp[b]:
            anc = _lca(a, b, up, depth)
            fixed_sum += depth[a] + depth[b] - 2 * depth[anc]
        else:
            weight[a] += 1
            weight[b] += 1
    return fixed_sum + _best_hub_cost(comp, parent, depth, order, weight)


def main():
    """Read ``N M Q``, M edges and Q queries from stdin; print the answer."""
    tokens = map(int, sys.stdin.read().split())
    n, m, q = next(tokens), next(tokens), next(tokens)
    # Tuple elements are evaluated left to right, so pairs stay in order.
    edges = [(next(tokens), next(tokens)) for _ in range(m)]
    queries = [(next(tokens), next(tokens)) for _ in range(q)]
    print(_solve(n, edges, queries))


if __name__ == '__main__':
    main()
0