結果

問題 No.2337 Equidistant
ユーザー ShirotsumeShirotsume
提出日時 2023-06-02 21:57:19
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 4,346 bytes
コンパイル時間 225 ms
コンパイル使用メモリ 82,396 KB
実行使用メモリ 232,488 KB
最終ジャッジ日時 2024-06-08 22:59:10
合計ジャッジ時間 8,685 ms
ジャッジサーバーID (参考情報)
judge2 / judge4
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 48 ms 62,080 KB
testcase_01 AC 45 ms 56,832 KB
testcase_02 AC 46 ms 56,832 KB
testcase_03 AC 44 ms 56,576 KB
testcase_04 AC 48 ms 56,832 KB
testcase_05 AC 51 ms 56,576 KB
testcase_06 AC 307 ms 82,356 KB
testcase_07 AC 317 ms 81,708 KB
testcase_08 AC 303 ms 82,116 KB
testcase_09 AC 329 ms 82,140 KB
testcase_10 AC 323 ms 82,420 KB
testcase_11 TLE -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys, time, random
from collections import deque, Counter, defaultdict
def input():
    """Read one line from stdin with trailing whitespace (incl. newline) removed."""
    return sys.stdin.readline().rstrip()


def ii():
    """Read one line and parse it as a single integer."""
    return int(input())


def mi():
    """Read one line and return a lazy iterator over its whitespace-separated ints."""
    return map(int, input().split())


def li():
    """Read one line and return its whitespace-separated ints as a list."""
    return list(mi())


inf = 2 ** 63 - 1  # large sentinel used as "infinity"
mod = 998244353

class segtree():
    """Bottom-up (iterative) segment tree over a fixed sequence ``V``.

    ``OP`` is the combining function and ``E`` its identity element.  Leaves
    occupy ``d[size : size + n]``; internal node ``k`` holds
    ``OP(d[2k], d[2k + 1])``.  Operands are always combined left-to-right, so
    a non-commutative ``OP`` is folded in sequence order.
    """

    def __init__(self, V, OP, E):
        self.n = len(V)
        self.op = OP
        self.e = E
        self.log = (self.n - 1).bit_length()
        self.size = 1 << self.log
        # Every slot starts as the identity; the leaf slots are then overwritten.
        self.d = [E] * (2 * self.size)
        self.d[self.size:self.size + self.n] = V
        # Build internal nodes from the bottom up so children are ready first.
        for k in reversed(range(1, self.size)):
            self.update(k)

    def prod(self, l, r):
        """Return ``OP`` folded over ``V[l:r]``; the identity when ``l == r``."""
        assert 0 <= l <= r <= self.n
        op = self.op
        tree = self.d
        left_acc = right_acc = self.e
        lo = l + self.size
        hi = r + self.size
        while lo < hi:
            if lo & 1:
                # lo is a right child: take it into the left accumulator.
                left_acc = op(left_acc, tree[lo])
                lo += 1
            if hi & 1:
                # hi - 1 is a left child: prepend it to the right accumulator.
                hi -= 1
                right_acc = op(tree[hi], right_acc)
            lo >>= 1
            hi >>= 1
        return op(left_acc, right_acc)

    def update(self, k):
        """Recompute internal node ``k`` from its two children."""
        child = k << 1
        self.d[k] = self.op(self.d[child], self.d[child + 1])

def EulerTour(s, graph):
    """Return the Euler tour (vertex sequence) of ``graph`` starting at ``s``.

    A vertex is emitted when it is entered and again each time one of its
    children is finished, so a connected tree on n vertices yields 2n - 1
    entries.  Children are entered in adjacency-list order.  The traversal
    uses an explicit stack in which a complemented entry ``~v`` means
    "return to v".
    """
    seen = [False] * len(graph)
    seen[s] = True
    stack = [s]
    push = stack.append
    tour = []
    while stack:
        top = stack.pop()
        if top < 0:
            # Returning from a child back to its parent ~top.
            tour.append(~top)
            continue
        tour.append(top)
        # Push in reverse so the first neighbour is popped (entered) first.
        for nxt in reversed(graph[top]):
            if seen[nxt]:
                continue
            seen[nxt] = True
            push(~top)
            push(nxt)
    return tour

def CalcDepth(s, graph):
    """Return the BFS distance (number of edges) from ``s`` to every vertex.

    Vertices unreachable from ``s`` keep the sentinel value ``2 ** 63 - 1``.
    """
    UNSEEN = 2 ** 63 - 1
    depth = [UNSEEN] * len(graph)
    depth[s] = 0
    frontier = [s]
    # Expand one BFS level at a time; every vertex first reached from this
    # level sits exactly one edge deeper.
    while frontier:
        next_level = []
        for v in frontier:
            child_depth = depth[v] + 1
            for w in graph[v]:
                if depth[w] == UNSEEN:
                    depth[w] = child_depth
                    next_level.append(w)
        frontier = next_level
    return depth


class LCA():
    """Lowest-common-ancestor queries on a tree rooted at vertex 0.

    Built on the Euler tour of the tree.  For vertices u and v, the shallowest
    vertex in the tour between min(first occurrence) and max(last occurrence)
    of u and v is their LCA.  Each query is therefore one range
    "minimum by depth" lookup in a ``segtree`` of (vertex, depth) pairs.
    """

    def __init__(self, graph):
        self.INF = 2 ** 63 - 1
        self.graph = graph
        self.N = len(graph)
        self.ET = EulerTour(0, graph)
        self.depth = CalcDepth(0, graph)

        # disc[v] / fin[v]: first / last position of v in the Euler tour.
        first = [-1] * self.N
        last = [-1] * self.N
        for pos, vertex in enumerate(self.ET):
            if first[vertex] == -1:
                first[vertex] = pos
            last[vertex] = pos
        self.disc = first
        self.fin = last

        def shallower(a, b):
            # On equal depth keep the left operand (deterministic tie-break).
            return a if a[1] <= b[1] else b

        leaves = [(vertex, self.depth[vertex]) for vertex in self.ET]
        self.S = segtree(leaves, shallower, (-1, self.INF))

    def lca(self, u, v):
        """Return the lowest common ancestor of vertices ``u`` and ``v``."""
        lo = min(self.disc[u], self.disc[v])
        hi = max(self.fin[u], self.fin[v])
        return self.S.prod(lo, hi + 1)[0]

    def dist(self, u, v):
        """Return the number of edges on the tree path between ``u`` and ``v``."""
        depth = self.depth
        return depth[u] + depth[v] - 2 * depth[self.lca(u, v)]
    
# ---- Read the tree (rooted at vertex 0) and precompute per-vertex data ----
n, q = mi()

graph = [[] for _ in range(n)]
for _ in range(n - 1):
    u, v = mi()
    u -= 1; v -= 1
    graph[u].append(v)
    graph[v].append(u)

# d[v]: number of edges from the root 0 to v.
d = CalcDepth(0, graph)

# subt[v]: number of vertices in the subtree rooted at v.  Visiting vertices
# deepest-first guarantees every child's count is final before its parent's.
subt = [1] * n
for v in sorted(range(n), key=d.__getitem__, reverse=True):
    for to in graph[v]:
        if d[to] > d[v]:
            subt[v] += subt[to]

L = LCA(graph)

# p[v]: parent of v, i.e. its unique shallower neighbour.  Comparing depths is
# O(n) overall.  The previous version issued one segment-tree LCA query per
# directed edge for the same result.  The root has no shallower neighbour and
# keeps itself as parent, so ancestor jumps saturate at the root.
p = [0] * n
for i in range(n):
    for to in graph[i]:
        if d[to] < d[i]:
            p[i] = to

# db[j][v]: the 2**j-th ancestor of v (clamped at the root), for j in 0..18.
db = [p[:]]
for _ in range(18):
    prev = db[-1]
    db.append([prev[a] for a in prev])

def _kth_ancestor(v, k):
    """Return the k-th ancestor of ``v`` (clamped at the root) using ``db``; k >= 0."""
    for row in db:
        if not k:
            break  # no remaining bits: stop early
        if k & 1:
            v = row[v]
        k >>= 1
    return v


def jump(s, t, i):
    """Return the vertex reached after walking ``i`` edges from ``s`` toward ``t``.

    Returns -1 when the s-t path has fewer than ``i`` edges.  Only one LCA
    query is issued (the previous version issued up to four); the path lengths
    are derived from depths relative to that LCA.
    """
    c = L.lca(s, t)
    depth = L.depth
    up = depth[s] - depth[c]     # edges from s up to the LCA
    down = depth[t] - depth[c]   # edges from the LCA down to t
    total = up + down
    if total < i:
        return -1
    if up >= i:
        # Target lies on the ascending part of the path: i-th ancestor of s.
        return _kth_ancestor(s, i)
    # Target lies on the descending part: it is (total - i) edges above t.
    return _kth_ancestor(t, total - i)

def size_of_subtree(s, t):
    """Return the size of the side containing ``t`` if the edge s-t were cut.

    Relies on the module-level depth table ``d``, subtree sizes ``subt`` and
    vertex count ``n``.  Only meaningful when ``s`` and ``t`` are adjacent;
    this is not checked here.
    """
    if d[s] >= d[t]:
        # t is not deeper than s: t's side is everything outside s's subtree.
        return n - subt[s]
    # t is deeper than s: t's side is exactly t's subtree.
    return subt[t]

def _ancestor(v, k):
    """Return the k-th ancestor of ``v`` (clamped at the root) using ``db``; k >= 0."""
    for row in db:
        if not k:
            break
        if k & 1:
            v = row[v]
        k >>= 1
    return v


# ---- Answer the queries ---------------------------------------------------
# Vertices equidistant from s and t exist only when dist(s, t) is even.  Let u
# be the midpoint of the s-t path.  The answer is n minus the size of u's
# branch toward s and minus the size of u's branch toward t.
# Each query uses a single LCA lookup plus direct ancestor jumps.  Answers are
# buffered and written once, instead of one print() call per query.
out = []
for _ in range(q):
    s, t = mi()
    s -= 1; t -= 1
    if s == t:
        # Every vertex is trivially equidistant from s and itself.  The
        # previous code looked up d[-1] / subt[-1] here and printed garbage.
        out.append(n)
        continue
    c = L.lca(s, t)
    ds = d[s] - d[c]   # edges from s up to the LCA
    dt = d[t] - d[c]   # edges from t up to the LCA
    if (ds + dt) % 2:
        out.append(0)
        continue
    if ds == dt:
        # Midpoint is the LCA itself; drop its child subtrees toward s and t.
        out.append(n - subt[_ancestor(s, ds - 1)] - subt[_ancestor(t, dt - 1)])
        continue
    if ds < dt:
        s, ds, dt = t, dt, ds   # make s the endpoint farther from the LCA
    half = (ds + dt) // 2       # 1 <= half < ds, so the midpoint is below c
    below = _ancestor(s, half - 1)  # child of the midpoint toward s
    mid = db[0][below]              # the midpoint itself
    # The branch toward t is everything outside mid's subtree; removing it and
    # the subtree of `below` leaves subt[mid] - subt[below] vertices.
    out.append(subt[mid] - subt[below])
if out:
    print("\n".join(map(str, out)))