結果
問題 |
No.922 東北きりきざむたん
|
ユーザー |
![]() |
提出日時 | 2025-05-14 12:51:04 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 4,104 bytes |
コンパイル時間 | 507 ms |
コンパイル使用メモリ | 82,364 KB |
実行使用メモリ | 199,840 KB |
最終ジャッジ日時 | 2025-05-14 12:52:03 |
合計ジャッジ時間 | 14,373 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 8 WA * 18 |
ソースコード
import sys
from sys import stdin
from collections import deque, defaultdict
from typing import List, Optional, Sequence, Tuple


def solve(n: int,
          edges: Sequence[Tuple[int, int]],
          queries: Sequence[Tuple[int, int]]) -> int:
    """Return the minimum total walking distance over all queries.

    The graph is a forest on vertices ``1..n``.  One "airport" vertex is
    chosen per connected component, and travel between airports costs
    nothing.  A query ``(a, b)`` costs

    * ``dist(a, b)`` when both endpoints are in the same component;
    * ``dist(a, airport(a)) + dist(b, airport(b))`` otherwise.

    Args:
        n: Number of vertices (vertices are 1-indexed).
        edges: Undirected edges ``(u, v)``; together they must form a forest.
        queries: Trips ``(a, b)``.  The same pair or endpoint may repeat, and
            every occurrence counts.

    Returns:
        The sum of query costs with every airport placed optimally.
    """
    adj: List[List[int]] = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # BFS every component.  ``parent[v] == 0`` means "not visited yet"; a root
    # is its own parent, which also removes the -1 checks from binary lifting.
    parent = [0] * (n + 1)
    depth = [0] * (n + 1)
    comp = [0] * (n + 1)          # component id == root vertex of its BFS tree
    order: List[int] = []         # BFS order: every parent precedes its children
    for root in range(1, n + 1):
        if parent[root]:
            continue
        parent[root] = root
        comp[root] = root
        queue = deque([root])
        while queue:
            v = queue.popleft()
            order.append(v)
            for w in adj[v]:
                if not parent[w]:
                    parent[w] = v
                    depth[w] = depth[v] + 1
                    comp[w] = root
                    queue.append(w)

    # Binary lifting table: up[k][v] is the 2**k-th ancestor of v (clamped at
    # the root).  2**log > n - 1 >= any depth, so `log` levels always suffice.
    log = max(1, n.bit_length())
    up = [parent]
    for _ in range(1, log):
        prev = up[-1]
        up.append([prev[prev[v]] for v in range(n + 1)])

    def lca(u: int, v: int) -> int:
        """Lowest common ancestor of two vertices in the same component."""
        if depth[u] < depth[v]:
            u, v = v, u
        diff = depth[u] - depth[v]
        k = 0
        while diff:                      # lift u to v's depth bit by bit
            if diff & 1:
                u = up[k][u]
            diff >>= 1
            k += 1
        if u == v:
            return u
        for k in range(log - 1, -1, -1):
            if up[k][u] != up[k][v]:
                u = up[k][u]
                v = up[k][v]
        return parent[u]

    total = 0
    # weight[v] = number of travellers that must walk between v and the
    # airport of v's component.  This has to be a count, not a set
    # membership flag: several queries may share an endpoint.
    weight = [0] * (n + 1)
    for a, b in queries:
        if comp[a] == comp[b]:
            total += depth[a] + depth[b] - 2 * depth[lca(a, b)]
        else:
            weight[a] += 1
            weight[b] += 1

    # Bottom-up pass (reverse BFS order visits children before parents):
    #   below[v] = total weight inside v's subtree
    #   down[v]  = weighted distance from v to everything in its subtree
    below = weight[:]
    down = [0] * (n + 1)
    for v in reversed(order):
        p = parent[v]
        if p != v:
            below[p] += below[v]
            down[p] += down[v] + below[v]

    # Top-down rerooting: moving the airport from p to its child v brings the
    # below[v] travellers one step closer and everyone else one step further.
    cost = [0] * (n + 1)
    best = {}                     # component root -> cheapest airport cost
    for v in order:
        p = parent[v]
        c = comp[v]
        if p == v:
            cost[v] = down[v]
            best[c] = cost[v]
        else:
            cost[v] = cost[p] + below[c] - 2 * below[v]
            if cost[v] < best[c]:
                best[c] = cost[v]

    # Components that no cross-component query touches have all-zero costs.
    return total + sum(best.values())


def main(data: Optional[str] = None) -> int:
    """Read one problem instance, print the answer and return it.

    Input format (whitespace separated): ``N M Q``, then ``M`` edges
    ``u v``, then ``Q`` queries ``a b``.

    Args:
        data: Full input text.  When ``None`` (the default, used when run as
            a script) the input is read from standard input.

    Returns:
        The printed answer, so callers can use it without parsing stdout.
    """
    tokens = (sys.stdin.read() if data is None else data).split()
    n, m, q = int(tokens[0]), int(tokens[1]), int(tokens[2])
    pos = 3
    edges = [(int(tokens[i]), int(tokens[i + 1]))
             for i in range(pos, pos + 2 * m, 2)]
    pos += 2 * m
    queries = [(int(tokens[i]), int(tokens[i + 1]))
               for i in range(pos, pos + 2 * q, 2)]

    answer = solve(n, edges, queries)
    print(answer)
    return answer


if __name__ == '__main__':
    main()