結果
| 問題 |
No.2337 Equidistant
|
| コンテスト | |
| ユーザー |
lam6er
|
| 提出日時 | 2025-04-16 00:59:45 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 3,061 bytes |
| コンパイル時間 | 146 ms |
| コンパイル使用メモリ | 82,912 KB |
| 実行使用メモリ | 198,888 KB |
| 最終ジャッジ日時 | 2025-04-16 01:01:15 |
| 合計ジャッジ時間 | 22,293 ms |
|
ジャッジサーバーID (参考情報) |
judge1 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | WA * 20 TLE * 1 -- * 7 |
ソースコード
import sys
from sys import stdin
sys.setrecursionlimit(1 << 25)
def main():
input = sys.stdin.read().split()
idx = 0
N, Q = int(input[idx]), int(input[idx+1])
idx += 2
adj = [[] for _ in range(N+1)]
for _ in range(N-1):
a = int(input[idx])
b = int(input[idx+1])
adj[a].append(b)
adj[b].append(a)
idx += 2
LOG = 20
parent = [[0]*(LOG+1) for _ in range(N+1)]
depth = [0]*(N+1)
sz = [1]*(N+1)
visited = [False]*(N+1)
from collections import deque
q = deque([1])
visited[1] = True
parent[1][0] = 0
while q:
u = q.popleft()
for v in adj[u]:
if not visited[v] and v != parent[u][0]:
parent[v][0] = u
depth[v] = depth[u] + 1
visited[v] = True
q.append(v)
stack = [(1, False)]
while stack:
u, visited_flag = stack.pop()
if visited_flag:
for v in adj[u]:
if parent[v][0] == u:
sz[u] += sz[v]
continue
stack.append((u, True))
for v in adj[u]:
if parent[v][0] == u and not visited_flag:
stack.append((v, False))
for k in range(1, LOG+1):
for u in range(1, N+1):
parent[u][k] = parent[parent[u][k-1]][k-1]
def lca(u, v):
if depth[u] < depth[v]:
u, v = v, u
for k in range(LOG, -1, -1):
if depth[u] - (1 << k) >= depth[v]:
u = parent[u][k]
if u == v:
return u
for k in range(LOG, -1, -1):
if parent[u][k] != parent[v][k]:
u = parent[u][k]
v = parent[v][k]
return parent[u][0]
def get_kth_ancestor(u, k):
current = u
for i in range(LOG, -1, -1):
if k >= (1 << i):
current = parent[current][i]
k -= (1 << i)
return current if k == 0 else 0
def find_child(u, mid):
current = u
while current != 0 and parent[current][0] != mid:
current = parent[current][0]
return current if parent[current][0] == mid else None
for _ in range(Q):
S = int(input[idx])
T = int(input[idx+1])
idx += 2
if S == T:
print(0)
continue
L = lca(S, T)
a = depth[S] - depth[L]
b = depth[T] - depth[L]
D = a + b
if D % 2 != 0:
print(0)
continue
mid_dist = D // 2
if mid_dist <= a:
mid = get_kth_ancestor(S, mid_dist)
else:
mid = get_kth_ancestor(T, b - (mid_dist - a))
u = find_child(S, mid)
v = find_child(T, mid)
if u is None or v is None:
print(0)
continue
ans = sz[mid] - sz[u] - sz[v]
print(ans if ans >= 0 else 0)
if __name__ == "__main__":
main()
lam6er