from collections import deque
from sys import stdin


class JumpOnTree:
    """Binary-lifting helper for LCA, distance and path-walking queries on a tree.

    Args:
        edges: adjacency list; ``edges[u]`` lists the neighbours of vertex ``u``
            (vertices are ``0 .. len(edges) - 1``).
        root: vertex the tree is rooted at.
    """

    def __init__(self, edges, root=0):
        self.n = len(edges)
        self.edges = edges
        self.root = root
        # Number of lifting levels: enough to climb any distance <= n - 1.
        self.logn = (self.n - 1).bit_length()
        self.depth = [-1] * self.n
        self.depth[self.root] = 0
        # parent[i][u] is the 2**i-th ancestor of u, or -1 past the root.
        self.parent = [[-1] * self.n for _ in range(self.logn)]
        self.dfs()
        self.doubling()

    def dfs(self):
        """Fill ``depth`` and the direct-parent table with an iterative DFS."""
        stack = [self.root]
        while stack:
            u = stack.pop()
            for v in self.edges[u]:
                if self.depth[v] == -1:  # unvisited
                    self.depth[v] = self.depth[u] + 1
                    self.parent[0][v] = u
                    stack.append(v)

    def doubling(self):
        """Build ``parent[i]`` (2**i-th ancestors) from ``parent[i - 1]``."""
        for i in range(1, self.logn):
            prev = self.parent[i - 1]
            cur = self.parent[i]
            for u in range(self.n):
                p = prev[u]
                if p != -1:
                    cur[u] = prev[p]

    def _ancestor(self, u, d):
        """Return the ``d``-th ancestor of ``u`` (``d`` must not exceed its depth)."""
        i = 0
        while d:
            if d & 1:
                u = self.parent[i][u]
            d >>= 1
            i += 1
        return u

    def lca(self, u, v):
        """Return the lowest common ancestor of ``u`` and ``v``."""
        du = self.depth[u]
        dv = self.depth[v]
        if du > dv:
            du, dv = dv, du
            u, v = v, u
        # Bring the deeper vertex up to the same depth.
        v = self._ancestor(v, dv - du)
        if u == v:
            return u
        # Climb both while their ancestors differ; they then sit just below the LCA.
        for i in range((du - 1).bit_length() - 1, -1, -1):
            pu = self.parent[i][u]
            pv = self.parent[i][v]
            if pu != pv:
                u = pu
                v = pv
        return self.parent[0][u]

    def dist(self, u, v):
        """Return the number of edges on the path between ``u`` and ``v``."""
        w = self.lca(u, v)
        return self.depth[u] + self.depth[v] - 2 * self.depth[w]

    def jump(self, u, v, k):
        """Return the vertex ``k`` steps from ``u`` along the path to ``v``.

        Returns -1 when the path is shorter than ``k`` edges.
        """
        if k == 0:
            return u
        p = self.lca(u, v)
        d1 = self.depth[u] - self.depth[p]  # edges from u up to the LCA
        d2 = self.depth[v] - self.depth[p]  # edges from v up to the LCA
        if d1 + d2 < k:
            return -1
        if k <= d1:
            return self._ancestor(u, k)
        # Target lies on the v side: climb from v instead.
        return self._ancestor(v, d1 + d2 - k)


def _subtree_sizes(n, edge, root=0):
    """Return ``(par, size)`` for the tree rooted at ``root``.

    ``par[v]`` is v's parent (-1 for the root); ``size[v]`` is the number of
    vertices in v's subtree, v included.
    """
    par = [-1] * n
    order = []
    todo = deque([root])
    while todo:
        v = todo.popleft()
        order.append(v)
        for u in edge[v]:
            if u != par[v]:
                par[u] = v
                todo.append(u)
    size = [1] * n
    # Reverse BFS order visits children before their parents.
    for v in reversed(order):
        if par[v] != -1:
            size[par[v]] += size[v]
    return par, size


def _branch_size(v, u, n, par, size):
    """Number of vertices on ``u``'s side once the edge ``(v, u)`` is cut."""
    return size[u] if par[u] == v else n - size[v]


def count_equidistant(n, edge, queries):
    """For each ``(s, t)`` pair, count vertices at equal distance from s and t.

    Args:
        n: number of vertices.
        edge: 0-indexed adjacency list of a tree on ``n`` vertices.
        queries: iterable of 0-indexed ``(s, t)`` pairs.

    Returns:
        List of answers, one per query.
    """
    jt = JumpOnTree(edge)
    par, size = _subtree_sizes(n, edge)
    ans = []
    for s, t in queries:
        if s == t:
            # Every vertex is equidistant from a single vertex.
            ans.append(n)
            continue
        d = jt.dist(s, t)
        if d % 2:
            # No vertex can be at equal distance when the path length is odd.
            ans.append(0)
            continue
        mid = jt.jump(s, t, d // 2)
        ng1 = jt.jump(mid, s, 1)  # mid's neighbour towards s
        ng2 = jt.jump(mid, t, 1)  # mid's neighbour towards t
        # Everything except the branches leading to s and to t qualifies.
        ans.append(
            n
            - _branch_size(mid, ng1, n, par, size)
            - _branch_size(mid, ng2, n, par, size)
        )
    return ans


def solve(text):
    """Parse the whole input ``text`` and return the output (one answer per line).

    Input format: ``n q``, then ``n - 1`` edges ``a b``, then ``q`` queries
    ``s t``; all vertex labels are 1-indexed.
    """
    data = iter(map(int, text.split()))
    n, q = next(data), next(data)
    edge = [[] for _ in range(n)]
    for _ in range(n - 1):
        a, b = next(data) - 1, next(data) - 1
        edge[a].append(b)
        edge[b].append(a)
    queries = [(next(data) - 1, next(data) - 1) for _ in range(q)]
    return '\n'.join(map(str, count_equidistant(n, edge, queries)))


def main():
    print(solve(stdin.read()))


if __name__ == "__main__":
    main()