結果

問題 No.2337 Equidistant
ユーザー ShirotsumeShirotsume
提出日時 2023-06-02 21:36:19
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,332 bytes
コンパイル時間 219 ms
コンパイル使用メモリ 82,176 KB
実行使用メモリ 156,544 KB
最終ジャッジ日時 2024-06-08 22:28:21
合計ジャッジ時間 30,343 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 50 ms
56,320 KB
testcase_01 WA -
testcase_02 WA -
testcase_03 WA -
testcase_04 WA -
testcase_05 WA -
testcase_06 WA -
testcase_07 WA -
testcase_08 WA -
testcase_09 WA -
testcase_10 WA -
testcase_11 WA -
testcase_12 WA -
testcase_13 WA -
testcase_14 WA -
testcase_15 WA -
testcase_16 WA -
testcase_17 WA -
testcase_18 WA -
testcase_19 WA -
testcase_20 WA -
testcase_21 WA -
testcase_22 WA -
testcase_23 WA -
testcase_24 WA -
testcase_25 WA -
testcase_26 WA -
testcase_27 WA -
testcase_28 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys, time, random
from collections import deque, Counter, defaultdict
input = lambda: sys.stdin.readline().rstrip()
ii = lambda: int(input())
mi = lambda: map(int, input().split())
li = lambda: list(mi())
inf = 2 ** 63 - 1
mod = 998244353

class segtree():
    def __init__(self,V,OP,E):
        self.n=len(V)
        self.op=OP
        self.e=E
        self.log=(self.n-1).bit_length()
        self.size=1<<self.log
        self.d=[E for i in range(2*self.size)]
        for i in range(self.n):
            self.d[self.size+i]=V[i]
        for i in range(self.size-1,0,-1):
            self.update(i)

    def prod(self,l,r):
        assert 0<=l and l<=r and r<=self.n
        sml=self.e
        smr=self.e
        l+=self.size
        r+=self.size
        while(l<r):
            if (l&1):
                sml=self.op(sml,self.d[l])
                l+=1
            if (r&1):
                smr=self.op(self.d[r-1],smr)
                r-=1
            l>>=1
            r>>=1
        return self.op(sml,smr)
  
    def update(self,k):
        self.d[k]=self.op(self.d[2*k],self.d[2*k+1])

def EulerTour(s, graph):
    n = len(graph)
    visit = [False] * n
    visit[s] = True
    q = [s]
    ret = []
    while q:
        now = q.pop()
        if now >= 0:
            ret.append(now)
            for to in graph[now][::-1]:
                if not visit[to]:
                    visit[to] = True
                    q.append(~now)
                    q.append(to)
        else:
            ret.append(~now)
    
    return ret

def CalcDepth(s, graph):
    INF = 2 ** 63 - 1
    from collections import deque
    n = len(graph)
    depth = [INF] * n
    depth[s] = 0
    q = deque()
    q.append(s)
    while q:
        now = q.popleft()
        for to in graph[now]:
            if depth[to] == INF:
                depth[to] = depth[now] + 1
                q.append(to)
    return depth


class LCA():
    def __init__(self, graph):
        self.INF = 2 ** 63 - 1
        self.graph = graph
        self.N = len(self.graph)
        self.ET = EulerTour(0, self.graph)
        self.depth = CalcDepth(0, graph)
        self.disc = [-1] * (self.N)
        self.fin = [-1] * (self.N)
        for i, v in enumerate(self.ET):
            if self.disc[v] == -1:
                self.disc[v] = i
            self.fin[v] = i
        self.S = segtree([(self.ET[i], self.depth[self.ET[i]]) for i in range(len(self.ET))], lambda x, y: x if x[1] <= y[1] else y, (-1, self.INF))
    
    def lca(self, u, v):
        st = min(self.disc[u], self.disc[v])
        en = max(self.fin[u], self.fin[v]) + 1
        ver, _ = self.S.prod(st, en)
        return ver
    
    def dist(self, u, v):
        a = self.lca(u, v)
        return self.depth[u] + self.depth[v] - 2 * self.depth[a]
    
n, q = mi()



graph = [[] for _ in range(n)]

for _ in range(n - 1):
    u, v = mi()
    u -= 1; v -= 1
    graph[u].append(v)
    graph[v].append(u)

d = CalcDepth(0, graph)

p = list(range(n))
p.sort(key = lambda x: d[x], reverse=True)

subt = [0] * n

for v in p:
    for to in graph[v]:
        if d[to] > d[v]:
            subt[v] += subt[to]
    subt[v] += 1
L = LCA(graph)

for _ in range(q):
    s, t = mi()
    s -= 1; t -= 1
    if L.dist(s, t) % 2 == 1:
        print(0)
    else:
        l = L.lca(s, t)
        print(n - subt[l] + 1)
0