結果
問題 |
No.922 東北きりきざむたん
|
ユーザー |
![]() |
提出日時 | 2025-04-24 12:27:15 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 6,259 bytes |
コンパイル時間 | 621 ms |
コンパイル使用メモリ | 82,088 KB |
実行使用メモリ | 78,376 KB |
最終ジャッジ日時 | 2025-04-24 12:28:35 |
合計ジャッジ時間 | 5,006 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 1 WA * 4 TLE * 1 -- * 20 |
ソースコード
import sys from sys import stdin from collections import deque, defaultdict def main(): sys.setrecursionlimit(1 << 25) input = sys.stdin.read().split() ptr = 0 N = int(input[ptr]); ptr +=1 M = int(input[ptr]); ptr +=1 Q = int(input[ptr]); ptr +=1 # DSU to find components class DSU: def __init__(self, size): self.parent = list(range(size+1)) self.rank = [0]*(size+1) def find(self, x): if self.parent[x] != x: self.parent[x] = self.find(self.parent[x]) return self.parent[x] def union(self, x, y): x_root = self.find(x) y_root = self.find(y) if x_root == y_root: return if self.rank[x_root] < self.rank[y_root]: self.parent[x_root] = y_root else: self.parent[y_root] = x_root if self.rank[x_root] == self.rank[y_root]: self.rank[x_root] +=1 dsu = DSU(N) edges = [] for _ in range(M): u = int(input[ptr]); ptr +=1 v = int(input[ptr]); ptr +=1 dsu.union(u, v) edges.append((u, v)) # Build adjacency list from the M edges adj = [[] for _ in range(N+1)] for u, v in edges: adj[u].append(v) adj[v].append(u) # Preprocess LCA for each component max_log = 20 parent = [[-1]*(N+1) for _ in range(max_log)] depth = [0]*(N+1) root_map = [0]*(N+1) visited = [False]*(N+1) for u in range(1, N+1): if not visited[u]: root = u queue = deque([root]) visited[root] = True parent[0][root] = -1 depth[root] = 0 root_map[root] = root while queue: current = queue.popleft() root_map[current] = root for v in adj[current]: if not visited[v] and dsu.find(v) == dsu.find(current): visited[v] = True parent[0][v] = current depth[v] = depth[current] + 1 queue.append(v) # Build binary lifting table for k in range(1, max_log): for u in range(1, N+1): if parent[k-1][u] != -1: parent[k][u] = parent[k-1][parent[k-1][u]] else: parent[k][u] = -1 # LCA function def lca(u, v): if depth[u] < depth[v]: u, v = v, u for k in range(max_log-1, -1, -1): if depth[u] - (1 << k) >= depth[v]: u = parent[k][u] if u == v: return u for k in range(max_log-1, -1, -1): if parent[k][u] != -1 and parent[k][u] != parent[k][v]: u = parent[k][u] v = parent[k][v] return parent[0][u] # Distance function def distance(u, v): ancestor = lca(u, v) return depth[u] + depth[v] - 2 * depth[ancestor] # Read queries same_total = 0 cross_queries = [] required = defaultdict(list) # root -> list of nodes for _ in range(Q): a = int(input[ptr]); ptr +=1 b = int(input[ptr]); ptr +=1 if dsu.find(a) == dsu.find(b): same_total += distance(a, b) else: cross_queries.append( (a, b) ) ra = root_map[a] rb = root_map[b] required[ra].append(a) required[rb].append(b) # Process each component with required nodes airport_dist = defaultdict(dict) # root -> {node: dist} for r in required: nodes = required[r] component = set() for node in nodes: component.add(node) # Build adjacency list for the component's tree adj_comp = [[] for _ in range(N+1)] for v in range(1, N+1): if root_map[v] == r and parent[0][v] != -1: u = parent[0][v] adj_comp[u].append(v) adj_comp[v].append(u) # Marked nodes marked = [False]*(N+1) for node in nodes: marked[node] = True # Post-order to compute cnt and sum_dist_sub cnt = [0]*(N+1) sum_dist_sub = [0]*(N+1) stack = [(r, False)] while stack: u, visited_flag = stack.pop() if not visited_flag: stack.append( (u, True) ) # Push children in reverse order to process in order for v in reversed(adj_comp[u]): if parent[0][v] == u and root_map[v] == r: stack.append( (v, False) ) else: current_cnt = 0 current_sum = 0 for v in adj_comp[u]: if parent[0][v] == u and root_map[v] == r: current_cnt += cnt[v] current_sum += sum_dist_sub[v] + cnt[v] if marked[u]: current_cnt += 1 cnt[u] = current_cnt sum_dist_sub[u] = current_sum # Pre-order to compute sum_dist sum_dist = [0]*(N+1) sum_dist[r] = sum_dist_sub[r] total_marked = len(nodes) stack = [r] while stack: u = stack.pop() for v in adj_comp[u]: if parent[0][v] == u and root_map[v] == r: sum_dist[v] = sum_dist[u] - cnt[v] + (total_marked - cnt[v]) stack.append(v) # Find node with minimal sum_dist min_sum = float('inf') airport = r for v in range(1, N+1): if root_map[v] == r and sum_dist[v] < min_sum: min_sum = sum_dist[v] airport = v # Precompute distances for required nodes for node in nodes: d = distance(node, airport) airport_dist[r][node] = d # Process cross_queries cross_total = 0 for a, b in cross_queries: ra = root_map[a] rb = root_map[b] da = airport_dist[ra][a] db = airport_dist[rb][b] cross_total += da + db total = same_total + cross_total print(total) if __name__ == '__main__': main()