結果
| 問題 |
No.2337 Equidistant
|
| コンテスト | |
| ユーザー |
lam6er
|
| 提出日時 | 2025-03-31 17:50:15 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 3,373 bytes |
| コンパイル時間 | 325 ms |
| コンパイル使用メモリ | 82,084 KB |
| 実行使用メモリ | 476,996 KB |
| 最終ジャッジ日時 | 2025-03-31 17:51:32 |
| 合計ジャッジ時間 | 22,716 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | WA * 21 TLE * 1 -- * 6 |
ソースコード
import sys
from sys import stdin
from collections import deque
sys.setrecursionlimit(1 << 25)
def main():
input = sys.stdin.read().split()
idx = 0
N, Q = int(input[idx]), int(input[idx+1])
idx += 2
edges = [[] for _ in range(N+1)]
for _ in range(N-1):
a, b = int(input[idx]), int(input[idx+1])
edges[a].append(b)
edges[b].append(a)
idx +=2
queries = []
for _ in range(Q):
s, t = int(input[idx]), int(input[idx+1])
queries.append((s, t))
idx +=2
LOG = 20
parent = [[-1]*(N+1) for _ in range(LOG)]
depth = [0]*(N+1)
size = [1]*(N+1)
visited = [False]*(N+1)
q = deque([1])
visited[1] = True
parent[0][1] = -1
adj = [[] for _ in range(N+1)]
while q:
u = q.popleft()
for v in edges[u]:
if not visited[v] and v != parent[0][u]:
visited[v] = True
parent[0][v] = u
depth[v] = depth[u] +1
q.append(v)
adj[u].append(v)
for k in range(1, LOG):
for v in range(1, N+1):
if parent[k-1][v] != -1:
parent[k][v] = parent[k-1][parent[k-1][v]]
def dfs(u, p):
for v in adj[u]:
if v != p:
dfs(v, u)
size[u] += size[v]
dfs(1, -1)
def lca(u, v):
if depth[u] < depth[v]:
u, v = v, u
for k in reversed(range(LOG)):
if depth[u] - (1 << k) >= depth[v]:
u = parent[k][u]
if u == v:
return u
for k in reversed(range(LOG)):
if parent[k][u] != parent[k][v]:
u = parent[k][u]
v = parent[k][v]
return parent[0][u]
def get_kth_ancestor(u, k):
if k <0:
return -1
for i in range(LOG):
if k & (1 << i):
u = parent[i][u]
if u == -1:
return -1
return u
def get_dist(u, v):
a = lca(u, v)
return depth[u] + depth[v] - 2 * depth[a]
out = []
for s, t in queries:
a_node = lca(s, t)
a = depth[s] - depth[a_node]
b = depth[t] - depth[a_node]
D = a + b
if D %2 !=0:
out.append('0')
continue
k = D//2
if a >=k:
m = get_kth_ancestor(s, k)
else:
rem = k -a
m = get_kth_ancestor(t, (b - rem))
l = get_kth_ancestor(s, k-1) if k !=0 else -1
l_tmp = get_kth_ancestor(t, (a +b -k) -1) if (a +b -k) !=0 else -1
if k ==0:
l = -1
s_part = l
if k ==0:
s_part = -1
if s_part ==-1:
s_size =0
else:
if parent[0][m] == s_part:
s_size = N - size[m]
else:
s_size = size[s_part] if s_part in adj[m] else 0
t_part = l_tmp
k_t = (a + b -k)
if k_t ==0:
t_part = -1
if t_part == -1:
t_size =0
else:
if parent[0][m] == t_part:
t_size = N - size[m]
else:
t_size = size[t_part] if t_part in adj[m] else 0
ans = size[m] - s_size - t_size
out.append(str(ans))
print('\n'.join(out))
if __name__ == '__main__':
main()
lam6er