結果

問題 No.2337 Equidistant
ユーザー ShirotsumeShirotsume
提出日時 2023-06-02 22:01:57
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 2,722 ms / 4,000 ms
コード長 3,588 bytes
コンパイル時間 160 ms
コンパイル使用メモリ 82,852 KB
実行使用メモリ 192,908 KB
最終ジャッジ日時 2024-06-08 23:07:45
合計ジャッジ時間 32,325 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 48 ms
55,808 KB
testcase_01 AC 48 ms
56,064 KB
testcase_02 AC 44 ms
56,064 KB
testcase_03 AC 46 ms
55,808 KB
testcase_04 AC 46 ms
56,192 KB
testcase_05 AC 44 ms
56,192 KB
testcase_06 AC 166 ms
79,720 KB
testcase_07 AC 185 ms
80,384 KB
testcase_08 AC 168 ms
79,736 KB
testcase_09 AC 170 ms
79,956 KB
testcase_10 AC 176 ms
80,604 KB
testcase_11 AC 1,478 ms
162,444 KB
testcase_12 AC 1,716 ms
167,776 KB
testcase_13 AC 1,471 ms
162,112 KB
testcase_14 AC 1,811 ms
162,236 KB
testcase_15 AC 1,487 ms
155,632 KB
testcase_16 AC 1,448 ms
161,828 KB
testcase_17 AC 1,478 ms
162,008 KB
testcase_18 AC 1,445 ms
155,636 KB
testcase_19 AC 1,752 ms
161,416 KB
testcase_20 AC 1,798 ms
169,176 KB
testcase_21 AC 1,360 ms
186,088 KB
testcase_22 AC 840 ms
173,492 KB
testcase_23 AC 1,299 ms
158,224 KB
testcase_24 AC 2,219 ms
192,908 KB
testcase_25 AC 1,332 ms
158,560 KB
testcase_26 AC 2,722 ms
192,908 KB
testcase_27 AC 1,466 ms
159,380 KB
testcase_28 AC 1,422 ms
158,380 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys, time, random
from collections import deque, Counter, defaultdict
input = lambda: sys.stdin.readline().rstrip()
ii = lambda: int(input())
mi = lambda: map(int, input().split())
li = lambda: list(mi())
inf = 2 ** 63 - 1
mod = 998244353

import sys


class LcaDoubling:
    """
    links[v] = { u1, u2, ... }  (u:隣接頂点, 親は含まない)
    というグラフ情報から、ダブリングによるLCAを構築。
    任意の2頂点のLCAを取得できるようにする
    """

    def __init__(self, n, links, root=0):
        self.depths = [-1] * n
        prev_ancestors = self._init_dfs(n, links, root)
        self.ancestors = [prev_ancestors]
        max_depth = max(self.depths)
        d = 1
        while d < max_depth:
            next_ancestors = [prev_ancestors[p] for p in prev_ancestors]
            self.ancestors.append(next_ancestors)
            d <<= 1
            prev_ancestors = next_ancestors

    def _init_dfs(self, n, links, root):
        q = [root]
        direct_ancestors = [-1] * (n + 1)  # 頂点数より1個長くし、存在しないことを-1で表す。末尾(-1)要素は常に-1
        self.depths[root] = 0
        while q:
            u = q.pop()
            for v in links[u]:
                if self.depths[v] != -1:
                    continue
                direct_ancestors[v] = u
                self.depths[v] = self.depths[u] + 1
                links[v].discard(u)
                q.append(v)
        return direct_ancestors

    def get_lca(self, u, v):
        du, dv = self.depths[u], self.depths[v]
        if du > dv:
            u, v = v, u
            du, dv = dv, du
        tu = u
        tv = self.upstream(v, dv - du)
        if u == tv:
            return u
        for k in range(du.bit_length() - 1, -1, -1):
            mu = self.ancestors[k][tu]
            mv = self.ancestors[k][tv]
            if mu != mv:
                tu = mu
                tv = mv
        lca = self.ancestors[0][tu]
        assert lca == self.ancestors[0][tv]
        return lca

    def upstream(self, v, k):
        i = 0
        while k:
            if k & 1:
                v = self.ancestors[i][v]
            k >>= 1
            i += 1
        return v

    def jump(self, u: int, v: int, i: int) -> int:
        """ uからvに向けて進んだパスのi番目(0-indexed)の頂点を得る。パス長が足りない場合は-1 """
        c = self.get_lca(u, v)
        du = self.depths[u]
        dv = self.depths[v]
        dc = self.depths[c]

        path_len = du - dc + dv - dc
        if path_len < i:
            return -1

        if du - dc >= i:
            return self.upstream(u, i)

        return self.upstream(v, path_len - i)
n, q = mi()

graph = [set() for _ in range(n)]

for _ in range(n - 1):
    u, v = mi()
    u -= 1; v -= 1
    graph[u].add(v)
    graph[v].add(u)

L = LcaDoubling(n, graph)
def size_of_subtree(s, t):
    if L.depths[s] < L.depths[t]:
        return subt[t]
    else:
        return n - subt[s]
p = list(range(n))
p.sort(key = lambda x: L.depths[x], reverse=True)

subt = [0] * n

for v in p:
    for to in graph[v]:
        if L.depths[to] > L.depths[v]:
            subt[v] += subt[to]
    subt[v] += 1

for _ in range(q):
    s, t = mi()
    s -= 1; t -= 1
    x = L.get_lca(s, t)
    l = L.depths[s] + L.depths[t] - 2 * L.depths[x]
    if l % 2 == 1:
        print(0)
    else:
        u = L.jump(s, t, l // 2)
        s1 = L.jump(u, s, 1)
        t1 = L.jump(u, t, 1)
        ans = n - size_of_subtree(u, s1) - size_of_subtree(u, t1)
        print(ans)


0