結果
| 問題 |
No.922 東北きりきざむたん
|
| ユーザー |
gew1fw
|
| 提出日時 | 2025-06-12 20:37:04 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
MLE
|
| 実行時間 | - |
| コード長 | 5,421 bytes |
| コンパイル時間 | 209 ms |
| コンパイル使用メモリ | 82,396 KB |
| 実行使用メモリ | 849,320 KB |
| 最終ジャッジ日時 | 2025-06-12 20:37:28 |
| 合計ジャッジ時間 | 4,586 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 4 |
| other | AC * 5 MLE * 1 -- * 20 |
ソースコード
import sys
from collections import defaultdict, deque
sys.setrecursionlimit(1 << 25)
def main():
input = sys.stdin.read().split()
ptr = 0
N = int(input[ptr]); ptr +=1
M = int(input[ptr]); ptr +=1
Q = int(input[ptr]); ptr +=1
parent = list(range(N+1))
rank = [1]*(N+1)
def find(u):
if parent[u] != u:
parent[u] = find(parent[u])
return parent[u]
def union(u, v):
u_root = find(u)
v_root = find(v)
if u_root == v_root:
return
if rank[u_root] < rank[v_root]:
parent[u_root] = v_root
else:
parent[v_root] = u_root
if rank[u_root] == rank[v_root]:
rank[u_root] +=1
original_adj = defaultdict(list)
for _ in range(M):
u = int(input[ptr]); ptr +=1
v = int(input[ptr]); ptr +=1
original_adj[u].append(v)
original_adj[v].append(u)
union(u, v)
components = defaultdict(list)
for u in range(1, N+1):
root = find(u)
components[root].append(u)
component_adj = defaultdict(dict)
for root, nodes in components.items():
adj = defaultdict(list)
for u in nodes:
for v in original_adj[u]:
if find(v) == root:
adj[u].append(v)
component_adj[root] = adj
log_level = 17
parent_table = {}
depth = {}
for root, nodes in components.items():
p_table = [ [0]*(log_level+1) for _ in range(N+1) ]
visited = [False]*(N+1)
q = deque()
q.append(nodes[0])
visited[nodes[0]] = True
p_table[nodes[0]][0] = 0
depth_node = {nodes[0]:0}
while q:
u = q.popleft()
for v in component_adj[root][u]:
if not visited[v]:
visited[v] = True
p_table[v][0] = u
depth_node[v] = depth_node[u] + 1
q.append(v)
for k in range(1, log_level+1):
for v in nodes:
p_table[v][k] = p_table[ p_table[v][k-1] ][k-1]
parent_table[root] = p_table
depth[root] = depth_node
def lca(u, v, root_comp):
p_table = parent_table[root_comp]
depth_u = depth[root_comp][u]
depth_v = depth[root_comp][v]
if depth_u > depth_v:
u, v = v, u
depth_u, depth_v = depth_v, depth_u
for k in range(log_level, -1, -1):
if depth_v - (1 << k) >= depth_u:
v = p_table[v][k]
depth_v -= (1 << k)
if u == v:
return u
for k in range(log_level, -1, -1):
if p_table[u][k] != p_table[v][k]:
u = p_table[u][k]
v = p_table[v][k]
return p_table[u][0]
fixed_sum = 0
cnt_a = defaultdict(int)
cnt_b = defaultdict(int)
for _ in range(Q):
a = int(input[ptr]); ptr +=1
b = int(input[ptr]); ptr +=1
root_a = find(a)
root_b = find(b)
if root_a == root_b:
l = lca(a, b, root_a)
distance = depth[root_a][a] + depth[root_a][b] - 2 * depth[root_a][l]
fixed_sum += distance
else:
cnt_a[a] +=1
cnt_b[b] +=1
cnt = defaultdict(int)
for x in cnt_a:
cnt[x] += cnt_a[x]
for x in cnt_b:
cnt[x] += cnt_b[x]
variable_sum = 0
for root, nodes in components.items():
adj = component_adj[root]
cnt_nodes = {x: cnt[x] for x in nodes}
total_cnt = sum(cnt_nodes.values())
if total_cnt == 0:
continue
size = {}
sum_cnt = {}
stack = [(nodes[0], False)]
parent_dict = {nodes[0]: -1}
while stack:
u, visited = stack.pop()
if visited:
s = cnt_nodes[u]
sz = 1
for v in adj[u]:
if v != parent_dict[u]:
s += sum_cnt[v]
sz += size[v]
size[u] = sz
sum_cnt[u] = s
else:
stack.append( (u, True) )
for v in adj[u]:
if v != parent_dict.get(u, -1):
parent_dict[v] = u
stack.append( (v, False) )
current = nodes[0]
while True:
max_child_sum = 0
best_child = -1
for v in adj[current]:
if v != parent_dict.get(current, -1):
if sum_cnt.get(v, 0) > max_child_sum:
max_child_sum = sum_cnt[v]
best_child = v
if max_child_sum > total_cnt / 2:
current = best_child
else:
break
distance = {}
q = deque()
q.append(current)
visited = set()
visited.add(current)
distance[current] = 0
while q:
u = q.popleft()
for v in adj[u]:
if v not in visited:
visited.add(v)
distance[v] = distance[u] + 1
q.append(v)
sum_component = 0
for x in nodes:
sum_component += cnt[x] * distance[x]
variable_sum += sum_component
print(fixed_sum + variable_sum)
if __name__ == '__main__':
main()
gew1fw