結果
問題 |
No.922 東北きりきざむたん
|
ユーザー |
![]() |
提出日時 | 2025-06-12 15:26:22 |
言語 | PyPy3 (7.3.15) |
結果 |
MLE
|
実行時間 | - |
コード長 | 5,421 bytes |
コンパイル時間 | 171 ms |
コンパイル使用メモリ | 82,388 KB |
実行使用メモリ | 849,284 KB |
最終ジャッジ日時 | 2025-06-12 15:26:58 |
合計ジャッジ時間 | 5,052 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 5 MLE * 1 -- * 20 |
ソースコード
import sys
from collections import defaultdict, deque
from typing import Iterable, List, Tuple

# Nothing below recurses; the raised limit is retained only so that the
# module's import-time behaviour stays the same as before.
sys.setrecursionlimit(1 << 25)


def solve(
    n: int,
    edges: Iterable[Tuple[int, int]],
    queries: Iterable[Tuple[int, int]],
) -> int:
    """Return the minimum total travel distance over all queries.

    Vertices are numbered 1..n.  The edges are treated as a forest (every
    connected component is walked as a tree).  For a query ``(a, b)``:

    * if ``a`` and ``b`` are in the same component, its cost is the tree
      distance between them;
    * otherwise its cost is ``dist(a, h_A) + dist(b, h_B)`` where ``h_A`` and
      ``h_B`` are one hub vertex per component, chosen once for all queries so
      that the grand total is as small as possible.

    Args:
        n: Number of vertices.
        edges: Pairs ``(u, v)`` of 1-based endpoints.
        queries: Pairs ``(a, b)`` of 1-based endpoints.

    Returns:
        The sum of all query costs with optimally placed hubs.

    All working storage is shared by the whole forest, so memory is
    O(n log n) no matter how many components there are.
    """
    adj: List[List[int]] = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # A single BFS sweep labels every vertex with its component (the id is the
    # BFS root, so 0 doubles as "not visited yet"), its depth and its parent.
    comp = [0] * (n + 1)
    depth = [0] * (n + 1)
    par = [0] * (n + 1)
    order: List[int] = []  # BFS order: every parent precedes its children
    for start in range(1, n + 1):
        if comp[start]:
            continue
        comp[start] = start
        par[start] = start  # roots point at themselves so lifting saturates
        queue = deque([start])
        while queue:
            u = queue.popleft()
            order.append(u)
            for v in adj[u]:
                if not comp[v]:
                    comp[v] = start
                    par[v] = u
                    depth[v] = depth[u] + 1
                    queue.append(v)

    # Binary-lifting table: up[k][v] is the 2**k-th ancestor of v.  Only as
    # many levels as the deepest vertex requires are built.
    levels = max(1, max(depth).bit_length())
    up = [par]
    for _ in range(1, levels):
        prev = up[-1]
        up.append([prev[p] for p in prev])

    def lca(u: int, v: int) -> int:
        """Lowest common ancestor of two vertices of the same component."""
        if depth[u] < depth[v]:
            u, v = v, u
        diff = depth[u] - depth[v]
        k = 0
        while diff:  # lift the deeper vertex up to the other one's depth
            if diff & 1:
                u = up[k][u]
            diff >>= 1
            k += 1
        if u == v:
            return u
        for k in range(levels - 1, -1, -1):
            nu, nv = up[k][u], up[k][v]
            if nu != nv:
                u, v = nu, nv
        return par[u]

    same_total = 0
    weight = [0] * (n + 1)  # number of cross-component endpoints per vertex
    for a, b in queries:
        if comp[a] == comp[b]:
            same_total += depth[a] + depth[b] - 2 * depth[lca(a, b)]
        else:
            weight[a] += 1
            weight[b] += 1

    # sub[v]: total weight inside the subtree of v (children are folded into
    # their parents by walking the BFS order backwards).
    sub = weight[:]
    for v in reversed(order):
        p = par[v]
        if p != v:
            sub[p] += sub[v]

    # cost[v]: weighted distance sum if the hub of v's component sits at v.
    # Seed each root directly, then reroot along every tree edge.
    cost = [0] * (n + 1)
    for v in order:
        cost[comp[v]] += weight[v] * depth[v]

    best = [0] * (n + 1)  # only the entries at component roots are used
    for v in order:
        root = comp[v]
        if v == root:
            best[root] = cost[root]
            continue
        # Moving the hub from the parent to v brings sub[v] weight one step
        # closer and pushes the remaining weight one step further away.
        moved = cost[par[v]] + sub[root] - 2 * sub[v]
        cost[v] = moved
        if moved < best[root]:
            best[root] = moved

    return same_total + sum(best)


def main() -> None:
    """Read ``N M Q``, M edges and Q queries from stdin; print the answer."""
    tokens = sys.stdin.read().split()
    n, m, q = int(tokens[0]), int(tokens[1]), int(tokens[2])
    values = list(map(int, tokens[3:3 + 2 * (m + q)]))
    split = 2 * m
    edges = zip(values[0:split:2], values[1:split:2])
    queries = zip(values[split::2], values[split + 1::2])
    print(solve(n, edges, queries))


if __name__ == '__main__':
    main()