## https://yukicoder.me/problems/no/2337
#
# For each query (S, T) on a tree with N vertices, count the vertices v with
# dist(S, v) == dist(T, v).
#
# Approach: root the tree at vertex 0 and use a rerooting pass so that, for
# every directed edge (v -> w), we know how many vertices lie on w's side.
# Binary lifting (doubling) answers LCA / k-th-ancestor queries.
# If dist(S, T) is odd the answer is 0.  Otherwise let M be the midpoint of
# the S-T path.  The answer is N minus the sizes of the two components
# adjacent to M that contain S and T.
import sys
from collections import deque


def calcurate_next_childs(N, next_nodes):
    """Root the tree at vertex 0 and size every directed edge.

    Args:
        N: number of vertices (0-indexed, 0 .. N-1).
        next_nodes: adjacency list; next_nodes[v] lists the neighbours of v.

    Returns:
        (next_childs, parents, depth), where
        - next_childs[v][w], for every neighbour w of v, is the number of
          vertices on w's side when the edge (v, w) is removed;
        - parents[v] is the parent of v in the tree rooted at 0 (-1 for root);
        - depth[v] is the distance from vertex 0 to v.
    """
    parents = [-2] * N
    parents[0] = -1
    # total_childs[v] starts as the number of descendants of v (subtree size
    # minus one).  After the rerooting pass it becomes N - 1.
    total_childs = [0] * N
    next_childs = [{} for _ in range(N)]

    # Pass 1: iterative post-order DFS.  Each stack entry is
    # (vertex, index of next neighbour to try), so deep trees cannot
    # overflow the recursion limit.
    stack = [(0, 0)]
    while stack:
        v, index = stack.pop()
        adj = next_nodes[v]
        while index < len(adj):
            w = adj[index]
            if w == parents[v]:
                index += 1
                continue
            parents[w] = v
            stack.append((v, index + 1))  # resume v after w is finished
            stack.append((w, 0))
            break
        if index == len(adj):
            # Every child of v is finished: report v's subtree to its parent.
            p = parents[v]
            if p != -1:
                total_childs[p] += 1 + total_childs[v]
                next_childs[p][v] = 1 + total_childs[v]

    # Pass 2 (rerooting): top-down BFS.  For a non-root v, `c` is the number
    # of vertices on the parent's side excluding the parent itself.
    depth = [0] * N
    queue = deque([(0, -1, 0)])
    while queue:
        v, c, d = queue.popleft()
        depth[v] = d
        p = parents[v]
        if p != -1:
            total_childs[v] += 1 + c
            next_childs[v][p] = 1 + c
        for w in next_nodes[v]:
            if w == p:
                continue
            queue.append((w, total_childs[v] - next_childs[v][w], d + 1))
    return next_childs, parents, depth


def _build_ancestor_table(parents):
    """Return the binary-lifting table for the tree described by *parents*.

    table[k][v] is the 2**k-th ancestor of v, or -1 above the root.
    table[0] is *parents* itself (not copied).
    """
    n = len(parents)
    max_k = 0
    while (1 << max_k) < n:
        max_k += 1
    table = [parents]
    for _ in range(max_k):
        prev = table[-1]
        table.append([prev[p] if p != -1 else -1 for p in prev])
    return table


def _kth_ancestor(table, v, k):
    """Return the k-th ancestor of v (v itself when k <= 0); k must be <= depth[v]."""
    bit = 0
    while k > 0 and v != -1:
        if k & 1:
            v = table[bit][v]
        k >>= 1
        bit += 1
    return v


def _lca(table, depth, s, t):
    """Return the lowest common ancestor of s and t."""
    # Make s the deeper vertex, then lift it to t's depth.
    if depth[s] < depth[t]:
        s, t = t, s
    s = _kth_ancestor(table, s, depth[s] - depth[t])
    if s == t:
        return s
    # Jumps past the root yield -1 for both vertices (equal), so they are skipped.
    for k in range(len(table) - 1, -1, -1):
        ps = table[k][s]
        pt = table[k][t]
        if ps != pt:
            s, t = ps, pt
    return table[0][s]


def solve(N, edges, queries):
    """Answer every query of yukicoder No.2337.

    Args:
        N: number of vertices.
        edges: the N - 1 tree edges as (u, v) pairs, 1-indexed.
        queries: (s, t) pairs, 1-indexed.

    Returns:
        A list with, for each query, the number of vertices equidistant
        from s and t (N when s == t).
    """
    next_nodes = [[] for _ in range(N)]
    for u, v in edges:
        next_nodes[u - 1].append(v - 1)
        next_nodes[v - 1].append(u - 1)

    # Rerooting DP: component sizes for every directed edge.
    next_childs, parents, depth = calcurate_next_childs(N, next_nodes)
    # Doubling table for LCA / k-th ancestor.
    table = _build_ancestor_table(parents)

    answers = []
    for s, t in queries:
        s -= 1
        t -= 1
        if s == t:
            # Every vertex is trivially equidistant from a single vertex.
            answers.append(N)
            continue
        lca = _lca(table, depth, s, t)
        dist = depth[s] + depth[t] - 2 * depth[lca]
        if dist % 2 != 0:
            answers.append(0)
            continue
        half = dist // 2
        if depth[s] == depth[t]:
            # The midpoint is the LCA; exclude the two branches toward s and t.
            s1 = _kth_ancestor(table, s, half - 1)
            t1 = _kth_ancestor(table, t, half - 1)
            answers.append(N - next_childs[lca][s1] - next_childs[lca][t1])
            continue
        # The midpoint lies strictly below the LCA on the deeper vertex's side.
        if depth[s] < depth[t]:
            s, t = t, s
        below = _kth_ancestor(table, s, half - 1)  # midpoint's child toward s
        mid = parents[below]
        above = parents[mid]  # midpoint's neighbour toward t
        answers.append(N - next_childs[mid][below] - next_childs[mid][above])
    return answers


def main():
    """Read the problem input from stdin and write one answer per line."""
    data = sys.stdin.buffer.read().split()
    N, Q = int(data[0]), int(data[1])
    values = list(map(int, data[2:]))
    edge_end = 2 * (N - 1)
    edges = list(zip(values[0:edge_end:2], values[1:edge_end:2]))
    query_vals = values[edge_end:edge_end + 2 * Q]
    queries = list(zip(query_vals[0::2], query_vals[1::2]))
    answers = solve(N, edges, queries)
    sys.stdout.write("".join(f"{a}\n" for a in answers))


if __name__ == "__main__":
    main()