結果

問題 No.2337 Equidistant
ユーザー lam6er
提出日時 2025-04-16 01:04:05
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,347 bytes
コンパイル時間 459 ms
コンパイル使用メモリ 81,804 KB
実行使用メモリ 199,092 KB
最終ジャッジ日時 2025-04-16 01:06:17
合計ジャッジ時間 19,891 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other WA * 21 TLE * 1 -- * 6
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
sys.setrecursionlimit(1 << 25)

def main():
    input = sys.stdin.read().split()
    ptr = 0
    N, Q = int(input[ptr]), int(input[ptr+1])
    ptr +=2

    adj = [[] for _ in range(N+1)]
    for _ in range(N-1):
        a = int(input[ptr])
        b = int(input[ptr+1])
        adj[a].append(b)
        adj[b].append(a)
        ptr +=2

    LOG = 20
    parent = [[-1]*(N+1) for _ in range(LOG)]
    depth = [0]*(N+1)
    in_time = [0]*(N+1)
    out_time = [0]*(N+1)
    size = [1]*(N+1)
    time = 0

    stack = [(1, -1, False)]
    while stack:
        u, p, visited = stack.pop()
        if visited:
            out_time[u] = time
            time +=1
            for v in adj[u]:
                if v != p:
                    size[u] += size[v]
            continue
        in_time[u] = time
        time +=1
        parent[0][u] = p
        depth[u] = depth[p] +1 if p != -1 else 0
        stack.append((u, p, True))
        for v in adj[u]:
            if v != p:
                stack.append((v, u, False))

    for k in range(1, LOG):
        for v in range(1, N+1):
            if parent[k-1][v] != -1:
                parent[k][v] = parent[k-1][parent[k-1][v]]
            else:
                parent[k][v] = -1

    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        for k in reversed(range(LOG)):
            if parent[k][u] != -1 and depth[parent[k][u]] >= depth[v]:
                u = parent[k][u]
        if u == v:
            return u
        for k in reversed(range(LOG)):
            if parent[k][u] != -1 and parent[k][u] != parent[k][v]:
                u = parent[k][u]
                v = parent[k][v]
        return parent[0][u]

    def lift(u, k):
        if k ==0:
            return u
        for i in range(LOG):
            if k & (1 << i):
                u = parent[i][u]
                if u == -1:
                    break
        return u

    for _ in range(Q):
        S = int(input[ptr])
        T = int(input[ptr+1])
        ptr +=2

        l = lca(S, T)
        a = depth[S] - depth[l]
        b = depth[T] - depth[l]
        D = a + b
        if D %2 !=0:
            print(0)
            continue
        mid = D//2
        if mid <= a:
            M = lift(S, mid)
            target = l
        else:
            mid_t = mid - a
            M = lift(T, b - mid_t)
            target = T

        if M == l:
            u_child = None
            for c in adj[M]:
                if c == parent[0][M]:
                    continue
                if in_time[c] <= in_time[S] <= out_time[c]:
                    u_child = c
                    break
            v_child = None
            for c in adj[M]:
                if c == parent[0][M]:
                    continue
                if in_time[c] <= in_time[T] <= out_time[c]:
                    v_child = c
                    break
            ans = size[M] - size[u_child] - size[v_child]
        else:
            v_child = None
            for c in adj[M]:
                if c == parent[0][M]:
                    continue
                if in_time[c] <= in_time[target] <= out_time[c]:
                    v_child = c
                    break
            ans = size[M] - size[v_child] if v_child is not None else 0
        print(ans)

if __name__ == '__main__':
    main()
0