結果
| 問題 | No.2337 Equidistant |
| コンテスト | |
| ユーザー | |
| 提出日時 | 2026-01-18 00:36:57 |
| 言語 | PyPy3 (7.3.17) |
| 結果 | RE |
| 実行時間 | - |
| コード長 | 4,145 bytes |
| 記録 | |
| コンパイル時間 | 427 ms |
| コンパイル使用メモリ | 82,288 KB |
| 実行使用メモリ | 243,472 KB |
| 最終ジャッジ日時 | 2026-01-18 00:37:28 |
| 合計ジャッジ時間 | 30,086 ms |
| ジャッジサーバーID (参考情報) | judge3 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 5 WA * 1 RE * 22 |
ソースコード
## https://yukicoder.me/problems/no/2337
from collections import deque
def calcurate_next_childs(N, next_nodes):
    """Root the tree at vertex 0 and measure every edge from both ends.

    Args:
        N: number of vertices (0-indexed, 0..N-1).
        next_nodes: adjacency lists; next_nodes[v] lists the neighbours of v.

    Returns:
        (next_childs, parents, depth) where
          next_childs[v][w] -- for each neighbour w of v, the number of
                               vertices on w's side of the edge v-w;
          parents[v]        -- parent of v when rooted at 0 (-1 for the root);
          depth[v]          -- number of edges from vertex 0 to v.
    """
    parents = [-2] * N
    parents[0] = -1
    depth = [0] * N

    # Pre-order walk with an explicit stack: each vertex is recorded before
    # any of its descendants, so `order` reversed is a valid bottom-up order.
    order = []
    pending = [0]
    while pending:
        v = pending.pop()
        order.append(v)
        for w in next_nodes[v]:
            if w == parents[v]:
                continue
            parents[w] = v
            depth[w] = depth[v] + 1
            pending.append(w)

    # Bottom-up subtree sizes; the downward edge p->v sees v's whole subtree.
    size = [1] * N
    next_childs = [{} for _ in range(N)]
    for v in reversed(order):
        p = parents[v]
        if p == -1:
            continue
        size[p] += size[v]
        next_childs[p][v] = size[v]

    # The upward edge v->p sees everything outside v's subtree.
    for v in order:
        p = parents[v]
        if p != -1:
            next_childs[v][p] = N - size[v]

    return next_childs, parents, depth
def main():
    """Solve yukicoder No.2337 "Equidistant".

    Reads N, Q, the N - 1 tree edges and Q queries (s, t) from stdin
    (1-indexed vertices). For each query, prints the number of vertices v
    with dist(v, s) == dist(v, t).
    """
    import sys

    # Read the whole input at once; per-line input() is slow for large N, Q.
    data = sys.stdin.buffer.read().split()
    N, Q = int(data[0]), int(data[1])
    next_nodes = [[] for _ in range(N)]
    pos = 2
    for _ in range(N - 1):
        u = int(data[pos]) - 1
        v = int(data[pos + 1]) - 1
        pos += 2
        next_nodes[u].append(v)
        next_nodes[v].append(u)
    st = []
    for _ in range(Q):
        st.append((int(data[pos]) - 1, int(data[pos + 1]) - 1))
        pos += 2

    # Rerooting DP: next_childs[v][w] = vertices on w's side of edge v-w.
    next_childs, parents, depth = calcurate_next_childs(N, next_nodes)

    # Binary lifting table: parents_list[k][v] is the 2**k-th ancestor of v,
    # or -1 when it does not exist.
    max_k = 0
    while (1 << max_k) < N:
        max_k += 1
    parents_list = [parents]
    for _ in range(max_k):
        prev = parents_list[-1]
        # Guard p == -1: prev[-1] would silently read the last vertex.
        parents_list.append([prev[p] if p != -1 else -1 for p in prev])

    def calc_p(v, d):
        """Return the d-th ancestor of v (v itself when d <= 0)."""
        for k in reversed(range(max_k + 1)):
            if d >= (1 << k):
                v = parents_list[k][v]
                d -= 1 << k
        return v

    def calc_lca(s, t):
        """Return the lowest common ancestor of s and t."""
        # Make s the deeper vertex and lift it to t's depth.
        if depth[s] < depth[t]:
            s, t = t, s
        s = calc_p(s, depth[s] - depth[t])
        if s == t:
            return s
        # Jump both only while their ancestors differ; they end up as the two
        # children of the LCA. Jumps past the root give -1 == -1 and are skipped.
        for k in reversed(range(max_k + 1)):
            ps = parents_list[k][s]
            pt = parents_list[k][t]
            if ps != pt:
                s, t = ps, pt
        return parents_list[0][s]

    # Answer the queries.
    out = []
    for s, t in st:
        if s == t:
            # Every vertex is equidistant from a vertex and itself.
            out.append(N)
            continue
        lca_v = calc_lca(s, t)
        dist = depth[s] + depth[t] - 2 * depth[lca_v]
        if dist % 2 != 0:
            # An odd-length path has no midpoint vertex.
            out.append(0)
            continue
        dist_half = dist // 2
        if depth[s] == depth[t]:
            # The midpoint is the LCA itself: exclude the branches toward s and t.
            s1 = calc_p(s, dist_half - 1)
            t1 = calc_p(t, dist_half - 1)
            ans = N - next_childs[lca_v][s1] - next_childs[lca_v][t1]
        else:
            # The midpoint lies strictly below the LCA on the deeper vertex's
            # side. Exclude the branch toward the deeper vertex and the branch
            # toward the parent, which leads to the other vertex.
            deep = s if depth[s] > depth[t] else t
            mid = calc_p(deep, dist_half)
            toward_deep = calc_p(deep, dist_half - 1)
            ans = N - next_childs[mid][toward_deep] - next_childs[mid][parents[mid]]
        out.append(ans)
    print("\n".join(map(str, out)))


if __name__ == "__main__":
    main()