結果
| 問題 |
No.922 東北きりきざむたん
|
| ユーザー |
qwewe
|
| 提出日時 | 2025-04-24 12:27:15 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 6,259 bytes |
| コンパイル時間 | 621 ms |
| コンパイル使用メモリ | 82,088 KB |
| 実行使用メモリ | 78,376 KB |
| 最終ジャッジ日時 | 2025-04-24 12:28:35 |
| 合計ジャッジ時間 | 5,006 ms |
|
ジャッジサーバーID (参考情報) |
judge1 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 4 |
| other | AC * 1 WA * 4 TLE * 1 -- * 20 |
ソースコード
import sys
from sys import stdin
from collections import deque, defaultdict
def main():
sys.setrecursionlimit(1 << 25)
input = sys.stdin.read().split()
ptr = 0
N = int(input[ptr]); ptr +=1
M = int(input[ptr]); ptr +=1
Q = int(input[ptr]); ptr +=1
# DSU to find components
class DSU:
def __init__(self, size):
self.parent = list(range(size+1))
self.rank = [0]*(size+1)
def find(self, x):
if self.parent[x] != x:
self.parent[x] = self.find(self.parent[x])
return self.parent[x]
def union(self, x, y):
x_root = self.find(x)
y_root = self.find(y)
if x_root == y_root:
return
if self.rank[x_root] < self.rank[y_root]:
self.parent[x_root] = y_root
else:
self.parent[y_root] = x_root
if self.rank[x_root] == self.rank[y_root]:
self.rank[x_root] +=1
dsu = DSU(N)
edges = []
for _ in range(M):
u = int(input[ptr]); ptr +=1
v = int(input[ptr]); ptr +=1
dsu.union(u, v)
edges.append((u, v))
# Build adjacency list from the M edges
adj = [[] for _ in range(N+1)]
for u, v in edges:
adj[u].append(v)
adj[v].append(u)
# Preprocess LCA for each component
max_log = 20
parent = [[-1]*(N+1) for _ in range(max_log)]
depth = [0]*(N+1)
root_map = [0]*(N+1)
visited = [False]*(N+1)
for u in range(1, N+1):
if not visited[u]:
root = u
queue = deque([root])
visited[root] = True
parent[0][root] = -1
depth[root] = 0
root_map[root] = root
while queue:
current = queue.popleft()
root_map[current] = root
for v in adj[current]:
if not visited[v] and dsu.find(v) == dsu.find(current):
visited[v] = True
parent[0][v] = current
depth[v] = depth[current] + 1
queue.append(v)
# Build binary lifting table
for k in range(1, max_log):
for u in range(1, N+1):
if parent[k-1][u] != -1:
parent[k][u] = parent[k-1][parent[k-1][u]]
else:
parent[k][u] = -1
# LCA function
def lca(u, v):
if depth[u] < depth[v]:
u, v = v, u
for k in range(max_log-1, -1, -1):
if depth[u] - (1 << k) >= depth[v]:
u = parent[k][u]
if u == v:
return u
for k in range(max_log-1, -1, -1):
if parent[k][u] != -1 and parent[k][u] != parent[k][v]:
u = parent[k][u]
v = parent[k][v]
return parent[0][u]
# Distance function
def distance(u, v):
ancestor = lca(u, v)
return depth[u] + depth[v] - 2 * depth[ancestor]
# Read queries
same_total = 0
cross_queries = []
required = defaultdict(list) # root -> list of nodes
for _ in range(Q):
a = int(input[ptr]); ptr +=1
b = int(input[ptr]); ptr +=1
if dsu.find(a) == dsu.find(b):
same_total += distance(a, b)
else:
cross_queries.append( (a, b) )
ra = root_map[a]
rb = root_map[b]
required[ra].append(a)
required[rb].append(b)
# Process each component with required nodes
airport_dist = defaultdict(dict) # root -> {node: dist}
for r in required:
nodes = required[r]
component = set()
for node in nodes:
component.add(node)
# Build adjacency list for the component's tree
adj_comp = [[] for _ in range(N+1)]
for v in range(1, N+1):
if root_map[v] == r and parent[0][v] != -1:
u = parent[0][v]
adj_comp[u].append(v)
adj_comp[v].append(u)
# Marked nodes
marked = [False]*(N+1)
for node in nodes:
marked[node] = True
# Post-order to compute cnt and sum_dist_sub
cnt = [0]*(N+1)
sum_dist_sub = [0]*(N+1)
stack = [(r, False)]
while stack:
u, visited_flag = stack.pop()
if not visited_flag:
stack.append( (u, True) )
# Push children in reverse order to process in order
for v in reversed(adj_comp[u]):
if parent[0][v] == u and root_map[v] == r:
stack.append( (v, False) )
else:
current_cnt = 0
current_sum = 0
for v in adj_comp[u]:
if parent[0][v] == u and root_map[v] == r:
current_cnt += cnt[v]
current_sum += sum_dist_sub[v] + cnt[v]
if marked[u]:
current_cnt += 1
cnt[u] = current_cnt
sum_dist_sub[u] = current_sum
# Pre-order to compute sum_dist
sum_dist = [0]*(N+1)
sum_dist[r] = sum_dist_sub[r]
total_marked = len(nodes)
stack = [r]
while stack:
u = stack.pop()
for v in adj_comp[u]:
if parent[0][v] == u and root_map[v] == r:
sum_dist[v] = sum_dist[u] - cnt[v] + (total_marked - cnt[v])
stack.append(v)
# Find node with minimal sum_dist
min_sum = float('inf')
airport = r
for v in range(1, N+1):
if root_map[v] == r and sum_dist[v] < min_sum:
min_sum = sum_dist[v]
airport = v
# Precompute distances for required nodes
for node in nodes:
d = distance(node, airport)
airport_dist[r][node] = d
# Process cross_queries
cross_total = 0
for a, b in cross_queries:
ra = root_map[a]
rb = root_map[b]
da = airport_dist[ra][a]
db = airport_dist[rb][b]
cross_total += da + db
total = same_total + cross_total
print(total)
if __name__ == '__main__':
main()
qwewe