結果

問題 No.2337 Equidistant
ユーザー FromBooskaFromBooska
提出日時 2023-06-02 22:19:29
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 4,098 bytes
コンパイル時間 274 ms
コンパイル使用メモリ 82,048 KB
実行使用メモリ 80,736 KB
最終ジャッジ日時 2024-06-08 23:37:59
合計ジャッジ時間 9,790 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 43 ms
56,696 KB
testcase_01 AC 46 ms
54,528 KB
testcase_02 AC 45 ms
54,528 KB
testcase_03 AC 43 ms
54,528 KB
testcase_04 AC 46 ms
54,528 KB
testcase_05 AC 42 ms
54,784 KB
testcase_06 AC 499 ms
79,680 KB
testcase_07 AC 502 ms
79,832 KB
testcase_08 AC 477 ms
80,180 KB
testcase_09 AC 494 ms
79,256 KB
testcase_10 AC 525 ms
80,736 KB
testcase_11 TLE -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

# オイラーツアーかなー
# st間の距離は求まる
# それが奇数なら0
# しかし偶数のときはそこから出る枝の数が必要
# 毎クエリで距離同じ頂点をカウントしたら間に合わないよな
# 一応BFSしてみる

# 木の頂点間の距離を高速に求める
# LCAパッケージ
# 使い方はこの例を見るしかない
# dummy, depth = LCA.query(a, b)で共通祖先からのdepth
# LCA.tour[LCA.in_time[a]][1]でルート1からのaまでの距離

# 違うLCAでやったらTLEしたのでこちらの人のを借用
# https://atcoder.jp/contests/abc014/submissions/37274146

# LCA
class LCA():
    def __init__(self, N, graph, root):
        self.tour, self.in_time = self.EulerTour(N, graph, root)
        self.seg = self.SegTree(self.tour)

    def query(self, u, v):
        uu = self.in_time[u]
        vv = self.in_time[v]
        if vv < uu:
            uu, vv = vv, uu
        position, depth = self.seg.query(uu, vv+1)
        return position, depth

    def EulerTour(self, N, graph, root):
        used = [False]*N
        q = [~root, root]
        tour = []
        in_time = [-1]*N
        time = -1
        d = -1
        while q:
            u = q.pop()
            if u < 0:
                time += 1
                if -N <= u:
                    d -= 1
                tour.append((u, d))
            if u >= 0:
                time += 1
                if in_time[u] < 0:
                    in_time[u] = time
                d += 1
                tour.append((u, d))
                flg = False
                for v in graph[u]:
                    if used[v]:
                        continue
                    used[v] = True
                    if flg:
                        q.append(~u-N)
                    q.append(~v)
                    q.append(v)
                    flg = True
        return tour, in_time

    class SegTree:
        def segfunc(self, x, y):
            if x[1] < y[1]:
                return x
            return y

        def __init__(self, init_val):
            n = len(init_val)
            self.ide_ele = (10**10, 10**10)
            self.num = 1 << (n-1).bit_length()
            self.tree = [self.ide_ele]*2*self.num
            for i in range(n):
                self.tree[self.num+i] = init_val[i]
            for i in range(self.num-1, 0, -1):
                self.tree[i] = self.segfunc(self.tree[2*i], self.tree[2*i+1])

        def query(self, l, r):
            res = self.ide_ele
            l += self.num
            r += self.num
            while l < r:
                if l & 1:
                    res = self.segfunc(res, self.tree[l])
                    l += 1
                if r & 1:
                    res = self.segfunc(res, self.tree[r-1])
                l >>= 1
                r >>= 1
            return res

from collections import deque

def BFS(start):
    INF = 10**8
    distance = [INF]*(N+1)
    distance[start] = 0
    que = deque()
    que.append(start)
    while que:
        current = que.popleft()
        for nxt in edges[current]:
            if distance[nxt] >= distance[current]+1:
                distance[nxt] = distance[current]+1
                que.append(nxt)
    return distance
        
N, Q = map(int, input().split())
edges = [[] for i in range(N+1)]
for i in range(N-1):
    a, b = map(int, input().split())
    edges[a].append(b)
    edges[b].append(a)

LCA = LCA(N+1, edges, 1) #3番目はルート
for q in range(Q):
    s, t = map(int, input().split())
    dummy, depth = LCA.query(s, t)
    # 第2varがLCA共通祖先からのdepth
    distance_s = LCA.tour[LCA.in_time[s]][1]
    # これでaから1の距離、つまりルートからの距離
    distance_t = LCA.tour[LCA.in_time[t]][1]
    distance_st = distance_s + distance_t - depth*2
    if distance_st%2 == 1:
        print(0)
    else:
        distance_s_list = BFS(s)
        distance_t_list = BFS(t)
        count = 0
        for i in range(1, N+1):
            if distance_s_list[i] == distance_t_list[i]:
                count += 1
        print(count)
0