結果

問題 No.2337 Equidistant
ユーザー FromBooskaFromBooska
提出日時 2023-10-11 15:02:45
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,922 bytes
コンパイル時間 538 ms
コンパイル使用メモリ 81,980 KB
実行使用メモリ 172,216 KB
最終ジャッジ日時 2024-09-13 13:55:03
合計ジャッジ時間 36,912 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 43 ms
55,584 KB
testcase_01 WA -
testcase_02 WA -
testcase_03 WA -
testcase_04 WA -
testcase_05 AC 43 ms
55,660 KB
testcase_06 WA -
testcase_07 WA -
testcase_08 WA -
testcase_09 WA -
testcase_10 WA -
testcase_11 WA -
testcase_12 WA -
testcase_13 WA -
testcase_14 WA -
testcase_15 WA -
testcase_16 WA -
testcase_17 WA -
testcase_18 WA -
testcase_19 WA -
testcase_20 WA -
testcase_21 AC 1,441 ms
150,316 KB
testcase_22 WA -
testcase_23 WA -
testcase_24 AC 1,680 ms
150,360 KB
testcase_25 WA -
testcase_26 AC 1,976 ms
150,092 KB
testcase_27 WA -
testcase_28 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

# ルート決めてLCA
# 2頂点のdepthが同じなら、LCAおよびその親、その先すべて
# 2頂点のdepthが異なり、そのdepth diffが奇数なら0
# 偶数なら1でいいか

## library of LCA by class
## index start from 0

import sys
sys.setrecursionlimit(10**7)
from collections import deque

class LCA:
    def __init__(self,n):
        self.size = n
        self.bitlen = n.bit_length()
        self.ancestor = [[0]*self.size for i in range(self.bitlen)]
        self.depth = [-1]*self.size
        self.dis = [-1]*self.size

    ## using [log_n][n] [n][log_n]
    ## [log_n][n] is tend to faster than [n][log_n]
    ## get parent by bfs is probably faster than dfs
    def make(self,root):
        self.depth[root] = 0
        self.dis[root] = 0
        q = deque([root])
        while q:
            now = q.popleft()
            for nex in edges[now]:
                if self.depth[nex]>= 0:
                    continue
                self.depth[nex] = self.depth[now]+1
                self.dis[nex] = self.dis[now]+1
                self.ancestor[0][nex] = now
                q.append(nex)
        for i in range(1,self.bitlen):
            for j in range(self.size):
                if self.ancestor[i-1][j] > 0:
                    self.ancestor[i][j] = self.ancestor[i-1][self.ancestor[i-1][j]]
    
    def lca(self,x,y):
        dx = self.depth[x]
        dy = self.depth[y]
        if dx < dy:
            x,y = y,x
            dx,dy = dy,dx
        dif = dx-dy
        while dif:
            s = dif & (-dif)
            x = self.ancestor[s.bit_length()-1][x]
            dif -= s
        while x != y:
            j = 0
            while self.ancestor[j][x] != self.ancestor[j][y]:
                j += 1
            if j == 0:
                return self.ancestor[0][x]
            x = self.ancestor[j-1][x]
            y = self.ancestor[j-1][y]
        return x

    def par(self,x,dep): #親parent
        now = x
        for i in range(self.bitlen)[::-1]:
            if 1 << i <= dep:
                now = self.ancestor[i][now]
                dep -= 1<<i
        return now

N, Q = map(int, input().split())
edges = [[] for i in range(N+1)]
for i in range(N-1):
    a, b = map(int, input().split())
    edges[a].append(b)
    edges[b].append(a)
    
lca = LCA(N+1) # 頂点数
lca.make(1) # ルート

# 部分木、子の数を数える
# dfsだとMLE, python3だとTLE
# なのでこの方のque方法にする

root = 1
child = [0]*(N+1)
visited = [0]*(N+1)
que = [root]
visited[root] = 1
topological = []
parent = [-1]*(N+1)
while que:
    current = que.pop()
    topological.append(current)
    for nxt in edges[current]:
        if visited[nxt] == 0:
            parent[nxt] = current
            visited[nxt] = 1
            que.append(nxt)

for current in topological[::-1]:
    count = 1
    for nxt in edges[current]:
        if nxt != parent[current]:
            count += child[nxt]
    child[current] = count
    
#print(child)

for q in range(Q):
    s, t = map(int, input().split())
    s_depth = lca.depth[s]
    t_depth = lca.depth[t]
    #print('s_depth', s_depth, 't_depth', t_depth)
    if s_depth == t_depth:
        lowest_common_ancestor = lca.lca(s, t)
        ans = N - child[lowest_common_ancestor] + 1
    elif (s_depth-t_depth)%2 == 0:
        if s_depth > t_depth:
            deeper = s
            shalower = t
        else:
            deeper = t
            shalower = s
        mid = lca.par(deeper, (s_depth+t_depth)//2)
        if mid == shalower:
            mid = lca.par(deeper, (s_depth+t_depth)//4)
        mid_depth_diff = lca.depth[deeper]-lca.depth[mid]
        ans = 1
        for nxt in edges[mid]:
            if nxt != lca.par(deeper, mid_depth_diff-1) and nxt != lca.par(deeper, mid_depth_diff+1):
                ans += child[nxt]
        #print('mid', mid)
    elif (s_depth-t_depth)%2 == 1:
        ans = 0
    print(ans)
0