結果

問題 No.2337 Equidistant
ユーザー とりゐとりゐ
提出日時 2023-06-02 22:09:23
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 2,939 ms / 4,000 ms
コード長 4,018 bytes
コンパイル時間 1,236 ms
コンパイル使用メモリ 87,124 KB
実行使用メモリ 244,872 KB
最終ジャッジ日時 2023-08-28 03:43:51
合計ジャッジ時間 37,365 ms
ジャッジサーバーID (参考情報) judge12 / judge15
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 94 ms 71,568 KB
testcase_01 AC 97 ms 72,208 KB
testcase_02 AC 96 ms 72,420 KB
testcase_03 AC 96 ms 72,364 KB
testcase_04 AC 95 ms 72,048 KB
testcase_05 AC 95 ms 72,372 KB
testcase_06 AC 231 ms 81,084 KB
testcase_07 AC 255 ms 83,988 KB
testcase_08 AC 225 ms 81,700 KB
testcase_09 AC 233 ms 80,888 KB
testcase_10 AC 239 ms 81,420 KB
testcase_11 AC 1,819 ms 194,448 KB
testcase_12 AC 1,758 ms 194,652 KB
testcase_13 AC 1,735 ms 193,760 KB
testcase_14 AC 1,760 ms 193,532 KB
testcase_15 AC 1,688 ms 196,376 KB
testcase_16 AC 1,742 ms 196,484 KB
testcase_17 AC 1,627 ms 195,848 KB
testcase_18 AC 1,768 ms 198,344 KB
testcase_19 AC 1,717 ms 195,532 KB
testcase_20 AC 1,738 ms 194,156 KB
testcase_21 AC 1,675 ms 192,432 KB
testcase_22 AC 1,116 ms 244,872 KB
testcase_23 AC 1,403 ms 200,272 KB
testcase_24 AC 2,309 ms 190,276 KB
testcase_25 AC 1,470 ms 198,216 KB
testcase_26 AC 2,939 ms 203,644 KB
testcase_27 AC 1,691 ms 215,980 KB
testcase_28 AC 1,714 ms 214,564 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

class JumpOnTree:
    """Binary-lifting (doubling) table over a tree given as adjacency lists.

    Answers lowest-common-ancestor, distance and "k-th vertex along a path"
    queries with O(log n) table lookups each, after O(n log n) preprocessing.

    Attributes:
        n: number of vertices, ``len(edges)``.
        edges: the adjacency lists passed in (not copied).
        root: vertex the tree is rooted at.
        logn: number of lifting levels, ``(n - 1).bit_length()``.
        depth: ``depth[u]`` is the edge count from ``root`` to ``u``
            (-1 for a vertex never reached from ``root``).
        parent: ``parent[i][u]`` is the ``2**i``-th ancestor of ``u``, or -1
            when that ancestor does not exist.
    """

    def __init__(self, edges, root=0):
        size = len(edges)
        levels = (size - 1).bit_length()
        self.n, self.edges, self.root, self.logn = size, edges, root, levels
        self.depth = [-1] * size
        self.depth[root] = 0
        self.parent = [[-1] * size for _ in range(levels)]
        self.dfs()
        self.doubling()

    def dfs(self):
        """Fill ``depth`` and ``parent[0]`` with an iterative DFS from ``root``."""
        depth = self.depth
        pending = [self.root]
        while pending:
            cur = pending.pop()
            for nxt in self.edges[cur]:
                if depth[nxt] != -1:
                    continue  # already visited (this includes cur's parent)
                depth[nxt] = depth[cur] + 1
                self.parent[0][nxt] = cur
                pending.append(nxt)

    def doubling(self):
        """Build lifting levels 1..logn-1.

        The 2**i-th ancestor is the 2**(i-1)-th ancestor of the 2**(i-1)-th
        ancestor. Entries whose intermediate ancestor is -1 stay -1.
        """
        for level in range(1, self.logn):
            below = self.parent[level - 1]
            here = self.parent[level]
            for u, p in enumerate(below):
                if p != -1:
                    here[u] = below[p]

    def lca(self, u, v):
        """Return the lowest common ancestor of ``u`` and ``v``."""
        if self.depth[u] > self.depth[v]:
            u, v = v, u
        shallow = self.depth[u]
        # Bring v up to u's depth, one set bit of the depth gap at a time.
        gap = self.depth[v] - shallow
        for bit in range(gap.bit_length()):
            if gap >> bit & 1:
                v = self.parent[bit][v]
        if u == v:
            return u
        # Climb both while their ancestors differ. At most shallow - 1 steps
        # are needed to reach the children of the LCA.
        for bit in reversed(range((shallow - 1).bit_length())):
            up_u = self.parent[bit][u]
            up_v = self.parent[bit][v]
            if up_u != up_v:
                u, v = up_u, up_v
        return self.parent[0][u]

    def dist(self, u, v):
        """Return the number of edges on the path between ``u`` and ``v``."""
        depth = self.depth
        return depth[u] + depth[v] - 2 * depth[self.lca(u, v)]

    def jump(self, u, v, k):
        """Return the vertex ``k`` edges from ``u`` on the path toward ``v``.

        Returns -1 when the path has fewer than ``k`` edges.
        """
        if k == 0:
            return u
        top = self.lca(u, v)
        rise = self.depth[u] - self.depth[top]
        length = rise + self.depth[v] - self.depth[top]
        if length < k:
            return -1
        if k <= rise:
            node, steps = u, k
        else:
            # The target lies past the LCA, so walk up from v instead.
            node, steps = v, length - k
        for bit in range(max(steps, 0).bit_length()):
            if steps >> bit & 1:
                node = self.parent[bit][node]
        return node





from sys import stdin
# Read one stdin line without its newline. Use rstrip("\n") rather than [:-1]:
# slicing unconditionally drops the last character, which chops a digit off a
# final line that has no trailing newline.
def input():
    return stdin.readline().rstrip("\n")


# Tree: n vertices (1-indexed in the input, 0-indexed here) and q queries.
n, q = map(int, input().split())
edge = [[] for _ in range(n)]
for _ in range(n - 1):
    a, b = map(lambda x: int(x) - 1, input().split())
    edge[a].append(b)
    edge[b].append(a)

JT = JumpOnTree(edge)

# Bucket each query (s, t) by the midpoint `mid` of the s-t path.
# query[mid] receives (i, ng1, ng2), where ng1/ng2 are mid's neighbours toward
# s and t. Vertices equidistant from s and t are the ones not reached from mid
# through ng1 or ng2; rerooting() counts them.
query = [[] for _ in range(n)]
for i in range(q):
    s, t = map(lambda x: int(x) - 1, input().split())
    d = JT.dist(s, t)
    if d % 2:
        # Odd path length: no midpoint vertex, so nothing is recorded and
        # rerooting() leaves this answer at 0.
        continue
    # NOTE(review): assumes s != t. With d == 0 the jumps below return -1
    # and the later neighbour lookup fails; confirm the input guarantees it.
    mid = JT.jump(s, t, d // 2)
    ng1 = JT.jump(mid, s, 1)
    ng2 = JT.jump(mid, t, 1)
    query[mid].append((i, ng1, ng2))







def rerooting(query):
    """Answer the midpoint-bucketed queries in ``query`` with a rerooting DP.

    ``query[v]`` holds tuples ``(i, ng1, ng2)``: ``i`` is an output slot and
    ``ng1``/``ng2`` are neighbours of ``v``. Answer ``i`` is ``n`` minus the
    folded values of the components entered from ``v`` through ``ng1`` and
    ``ng2``. Under this file's E/f/g/merge those values are vertex counts.
    Slots that receive no tuple stay 0.

    Reads module globals: ``edge`` (adjacency lists), ``n``, ``q``,
    ``order`` (vertex order with each parent before its children), ``par``
    (parent of each vertex, -1 for the root) and the monoid
    ``E``/``f``/``g``/``merge``.
    """
    # dp[v][j]: folded value of the component reached from v via edge[v][j].
    dp = [[E] * len(edge[v]) for v in range(n)]

    # Pass 1 (children before parents): fold every subtree bottom-up.
    down = [E] * n
    for v in reversed(order):
        acc = E
        for j, w in enumerate(edge[v]):
            if w == par[v]:
                continue
            dp[v][j] = down[w]
            acc = merge(acc, f(dp[v][j], w))
        down[v] = g(acc, v)

    # Pass 2 (parents before children): up[w] folds everything outside w's
    # subtree, assembled from prefix/suffix folds over v's other neighbours.
    up = [E] * n
    for v in order:
        adj = edge[v]
        row = dp[v]
        p = par[v]
        for j, w in enumerate(adj):
            if w == p:
                row[j] = up[v]
        deg = len(adj)
        suffix = [E] * (deg + 1)
        for j in range(deg - 1, -1, -1):
            suffix[j] = merge(suffix[j + 1], f(row[j], adj[j]))
        prefix = E
        for j, w in enumerate(adj):
            if w != p:
                up[w] = g(merge(prefix, suffix[j + 1]), v)
            prefix = merge(prefix, f(row[j], w))

    ans = [0] * q
    for v in range(n):
        bucket = query[v]
        if not bucket:
            continue
        fold_via = dict(zip(edge[v], dp[v]))
        for idx, ng1, ng2 in bucket:
            ans[idx] = n - fold_via[ng1] - fold_via[ng2]
    return ans


# Monoid plugged into rerooting(): a component's folded value is its vertex
# count.
E = 0  # identity of merge(): the fold of an empty set of neighbours


def f(res, v):
    """Adjust a neighbour's folded value before merging (no-op here)."""
    return res


def g(res, v):
    """Close off vertex ``v``: count ``v`` itself on top of the merged total."""
    return 1 + res


def merge(a, b):
    """Combine two partial folded values by adding their vertex counts."""
    combined = a + b
    return combined


def calc_ans(res, v):
    """Per-vertex final value. Delegates to g(); not called in this file as shown."""
    return g(res, v)


# make order table
# root = 0

from collections import deque
# Root the tree at vertex 0 with a BFS. `order` lists the vertices so that every
# parent precedes its children (rerooting() walks it in both directions), and
# `par[v]` is v's parent (-1 for the root).
order = []
par = [-1] * n
frontier = deque([0])
while frontier:
    v = frontier.popleft()
    order.append(v)
    for u in edge[v]:
        if u == par[v]:
            continue
        par[u] = v
        frontier.append(u)

ans = rerooting(query)
print("\n".join(map(str, ans)))
0