結果
問題 |
No.922 東北きりきざむたん
|
ユーザー |
![]() |
提出日時 | 2025-04-24 12:27:16 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 4,798 bytes |
コンパイル時間 | 240 ms |
コンパイル使用メモリ | 82,068 KB |
実行使用メモリ | 849,532 KB |
最終ジャッジ日時 | 2025-04-24 12:28:39 |
合計ジャッジ時間 | 3,425 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | WA * 5 MLE * 1 -- * 20 |
ソースコード
import sys from sys import stdin from sys import setrecursionlimit from collections import defaultdict, deque input = sys.stdin.readline setrecursionlimit(1 << 25) def main(): sys.setrecursionlimit(1 << 25) N, M, Q = map(int, input().split()) edges = [[] for _ in range(N+1)] parent = [i for i in range(N+1)] rank = [1]*(N+1) def find(u): while parent[u] != u: parent[u] = parent[parent[u]] u = parent[u] return u def union(u, v): u = find(u) v = find(v) if u == v: return if rank[u] < rank[v]: parent[u] = v rank[v] += rank[u] else: parent[v] = u rank[u] += rank[v] adj_components = defaultdict(list) for _ in range(M): u, v = map(int, input().split()) union(u, v) adj_components[find(u)].append((u, v)) component_roots = set(find(u) for u in range(1, N+1)) component_info = {} for root in component_roots: edges = adj_components.get(root, []) adj = defaultdict(list) for u, v in edges: adj[u].append(v) adj[v].append(u) component_info[root] = { 'adj': adj, 'depth': {}, 'up': None, 'nodes': [] } visited = set() q = deque([root]) visited.add(root) parent_dict = {root: -1} depth_dict = {root: 0} nodes = [] while q: u = q.popleft() nodes.append(u) for v in adj[u]: if v not in visited: visited.add(v) parent_dict[v] = u depth_dict[v] = depth_dict[u] + 1 q.append(v) component_info[root]['nodes'] = nodes component_info[root]['depth'] = depth_dict LOG = 20 up = [[-1]*(N+1) for _ in range(LOG)] for u in nodes: up[0][u] = parent_dict.get(u, -1) for k in range(1, LOG): for u in nodes: if up[k-1][u] != -1: up[k][u] = up[k-1][up[k-1][u]] else: up[k][u] = -1 component_info[root]['up'] = up def lca(u, v, root): up = component_info[root]['up'] depth = component_info[root]['depth'] if depth[u] < depth[v]: u, v = v, u for k in range(19, -1, -1): if depth[u] - (1 << k) >= depth[v]: u = up[k][u] if u == v: return u for k in range(19, -1, -1): if up[k][u] != -1 and up[k][u] != up[k][v]: u = up[k][u] v = up[k][v] return up[0][u] total = 0 S = set() X = defaultdict(list) for _ in range(Q): a, b = map(int, input().split()) a_root = find(a) b_root = find(b) if a_root == b_root: depth = component_info[a_root]['depth'] lca_node = lca(a, b, a_root) distance = depth[a] + depth[b] - 2 * depth[lca_node] total += distance else: S.add(a_root) S.add(b_root) X[a_root].append(a) X[b_root].append(b) for root in S: nodes = component_info[root]['nodes'] adj = component_info[root]['adj'] X_C = X[root] if not X_C: continue marked = defaultdict(bool) for node in X_C: marked[node] = True total_X = len(X_C) cnt = defaultdict(int) sum_sub = defaultdict(int) parent_dict = {} stack = [(root, None)] post_order = [] while stack: u, p = stack.pop() post_order.append(u) parent_dict[u] = p for v in adj[u]: if v != p: stack.append((v, u)) post_order.reverse() for u in post_order: cnt[u] = 1 if marked[u] else 0 sum_sub[u] = 0 for v in adj[u]: if v == parent_dict.get(u, None): continue cnt[u] += cnt[v] sum_sub[u] += sum_sub[v] + cnt[v] sum_dist = defaultdict(int) sum_dist[root] = sum_sub[root] q = deque() q.append((root, None)) while q: u, p = q.popleft() for v in adj[u]: if v == p: continue sum_dist[v] = sum_dist[u] - cnt[v] + (total_X - cnt[v]) q.append((v, u)) min_sum = min(sum_dist[node] for node in nodes) total += min_sum print(total) if __name__ == '__main__': main()