結果
問題 |
No.922 東北きりきざむたん
|
ユーザー |
![]() |
提出日時 | 2025-04-09 20:59:59 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 4,240 bytes |
コンパイル時間 | 325 ms |
コンパイル使用メモリ | 82,300 KB |
実行使用メモリ | 145,708 KB |
最終ジャッジ日時 | 2025-04-09 21:01:01 |
合計ジャッジ時間 | 5,915 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 6 TLE * 1 -- * 19 |
ソースコード
import sys from sys import stdin from collections import deque, defaultdict sys.setrecursionlimit(1 << 25) def main(): input = sys.stdin.read().split() ptr = 0 N, M, Q = map(int, input[ptr:ptr+3]) ptr += 3 edges = [[] for _ in range(N+1)] uf = list(range(N+1)) def find(u): while uf[u] != u: uf[u] = uf[uf[u]] u = uf[u] return u def union(u, v): u_root = find(u) v_root = find(v) if u_root != v_root: uf[v_root] = u_root for _ in range(M): u = int(input[ptr]) v = int(input[ptr+1]) ptr += 2 edges[u].append(v) edges[v].append(u) union(u, v) comp = defaultdict(list) for city in range(1, N+1): comp[find(city)].append(city) S = set() req = defaultdict(list) queries = [] for _ in range(Q): a = int(input[ptr]) b = int(input[ptr+1]) ptr += 2 queries.append((a, b)) ca = find(a) cb = find(b) if ca != cb: S.add(ca) S.add(cb) req[ca].append(a) req[cb].append(b) comp_nodes = {} comp_edges = {} for c in comp: nodes = comp[c] comp_nodes[c] = nodes edge = defaultdict(list) for u in nodes: for v in edges[u]: if find(v) == c: edge[u].append(v) comp_edges[c] = edge lca_data = {} for c in comp: nodes = comp_nodes[c] edge = comp_edges[c] parent = {} depth = {} root = nodes[0] queue = deque([root]) parent[root] = -1 depth[root] = 0 while queue: u = queue.popleft() for v in edge[u]: if v != parent[u]: parent[v] = u depth[v] = depth[u] + 1 queue.append(v) lca_data[c] = (parent, depth) def get_dist(u, v): if u == v: return 0 cu = find(u) cv = find(v) if cu != cv: return -1 parent, depth = lca_data[cu] pu, pv = u, v while depth[pu] > depth[pv]: pu = parent[pu] while depth[pv] > depth[pu]: pv = parent[pv] while pu != pv: pu = parent[pu] pv = parent[pv] lca_node = pu return depth[u] + depth[v] - 2 * depth[lca_node] airport_pos = {} dist_from_airport = defaultdict(dict) for c in S: nodes = comp_nodes[c] edge = comp_edges[c] req_vertices = req[c] sum_dist_min = float('inf') best_x = -1 sum_count = {} for u in req_vertices: visited = set() q = deque() q.append((u, 0)) visited.add(u) while q: v, d = q.popleft() sum_count[v] = sum_count.get(v, 0) + d for w in edge[v]: if w not in visited: visited.add(w) q.append((w, d+1)) for v in sum_count: if sum_count[v] < sum_dist_min: sum_dist_min = sum_count[v] best_x = v airport_pos[c] = best_x visited = set() q = deque() q.append((best_x, 0)) visited.add(best_x) dist_from_airport[c][best_x] = 0 while q: u, d = q.popleft() for v in edge[u]: if v not in visited: visited.add(v) dist_from_airport[c][v] = d + 1 q.append((v, d + 1)) total = 0 for a, b in queries: ca = find(a) cb = find(b) if ca == cb: dist = get_dist(a, b) total += dist else: if ca not in airport_pos or cb not in airport_pos: pass else: xa = airport_pos[ca] xb = airport_pos[cb] da = dist_from_airport[ca].get(a, 0) db = dist_from_airport[cb].get(b, 0) total += da + db print(total) if __name__ == '__main__': main()