結果
問題 |
No.2337 Equidistant
|
ユーザー |
![]() |
提出日時 | 2025-03-31 17:50:15 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 3,373 bytes |
コンパイル時間 | 325 ms |
コンパイル使用メモリ | 82,084 KB |
実行使用メモリ | 476,996 KB |
最終ジャッジ日時 | 2025-03-31 17:51:32 |
合計ジャッジ時間 | 22,716 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | WA * 21 TLE * 1 -- * 6 |
ソースコード
import sys from sys import stdin from collections import deque sys.setrecursionlimit(1 << 25) def main(): input = sys.stdin.read().split() idx = 0 N, Q = int(input[idx]), int(input[idx+1]) idx += 2 edges = [[] for _ in range(N+1)] for _ in range(N-1): a, b = int(input[idx]), int(input[idx+1]) edges[a].append(b) edges[b].append(a) idx +=2 queries = [] for _ in range(Q): s, t = int(input[idx]), int(input[idx+1]) queries.append((s, t)) idx +=2 LOG = 20 parent = [[-1]*(N+1) for _ in range(LOG)] depth = [0]*(N+1) size = [1]*(N+1) visited = [False]*(N+1) q = deque([1]) visited[1] = True parent[0][1] = -1 adj = [[] for _ in range(N+1)] while q: u = q.popleft() for v in edges[u]: if not visited[v] and v != parent[0][u]: visited[v] = True parent[0][v] = u depth[v] = depth[u] +1 q.append(v) adj[u].append(v) for k in range(1, LOG): for v in range(1, N+1): if parent[k-1][v] != -1: parent[k][v] = parent[k-1][parent[k-1][v]] def dfs(u, p): for v in adj[u]: if v != p: dfs(v, u) size[u] += size[v] dfs(1, -1) def lca(u, v): if depth[u] < depth[v]: u, v = v, u for k in reversed(range(LOG)): if depth[u] - (1 << k) >= depth[v]: u = parent[k][u] if u == v: return u for k in reversed(range(LOG)): if parent[k][u] != parent[k][v]: u = parent[k][u] v = parent[k][v] return parent[0][u] def get_kth_ancestor(u, k): if k <0: return -1 for i in range(LOG): if k & (1 << i): u = parent[i][u] if u == -1: return -1 return u def get_dist(u, v): a = lca(u, v) return depth[u] + depth[v] - 2 * depth[a] out = [] for s, t in queries: a_node = lca(s, t) a = depth[s] - depth[a_node] b = depth[t] - depth[a_node] D = a + b if D %2 !=0: out.append('0') continue k = D//2 if a >=k: m = get_kth_ancestor(s, k) else: rem = k -a m = get_kth_ancestor(t, (b - rem)) l = get_kth_ancestor(s, k-1) if k !=0 else -1 l_tmp = get_kth_ancestor(t, (a +b -k) -1) if (a +b -k) !=0 else -1 if k ==0: l = -1 s_part = l if k ==0: s_part = -1 if s_part ==-1: s_size =0 else: if parent[0][m] == s_part: s_size = N - size[m] else: s_size = size[s_part] if s_part in adj[m] else 0 t_part = l_tmp k_t = (a + b -k) if k_t ==0: t_part = -1 if t_part == -1: t_size =0 else: if parent[0][m] == t_part: t_size = N - size[m] else: t_size = size[t_part] if t_part in adj[m] else 0 ans = size[m] - s_size - t_size out.append(str(ans)) print('\n'.join(out)) if __name__ == '__main__': main()