結果

問題 No.2337 Equidistant
ユーザー lam6er
提出日時 2025-04-16 15:37:05
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,332 bytes
コンパイル時間 386 ms
コンパイル使用メモリ 82,092 KB
実行使用メモリ 217,332 KB
最終ジャッジ日時 2025-04-16 15:42:53
合計ジャッジ時間 23,949 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 6 WA * 22
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
# Slurp all of stdin once and split it into whitespace-separated tokens;
# main() consumes them left to right through the module-level cursor ``idx``.
input = sys.stdin.read
data, idx = input().split(), 0

def _solve(n, edges, queries):
    """Answer every query of yukicoder No.2337 "Equidistant".

    For each query ``(s, t)`` count the vertices ``v`` of the tree with
    ``dist(v, s) == dist(v, t)``.

    Args:
        n: Number of vertices; vertices are labelled ``1..n``.
        edges: The ``n - 1`` tree edges as ``(a, b)`` pairs.
        queries: ``(s, t)`` vertex pairs.

    Returns:
        A list with one count per query, in query order.
    """
    adj = [[] for _ in range(n + 1)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    # Iterative DFS from vertex 1 (no recursion-limit issues). ``order`` lists
    # every vertex after its parent, so walking it backwards visits children
    # before parents -- exactly what the subtree-size accumulation needs.
    parent = [-1] * (n + 1)
    depth = [0] * (n + 1)
    order = []
    stack = [1]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                depth[v] = depth[u] + 1
                stack.append(v)

    size = [1] * (n + 1)
    for u in reversed(order):
        if parent[u] != -1:
            size[parent[u]] += size[u]

    # Binary-lifting table: up[k][v] is the 2**k-th ancestor of v, or -1.
    # Depths are < n, so n.bit_length() levels always suffice (the old fixed
    # LOG = 20 silently broke for trees deeper than 2**20).
    log = max(1, n.bit_length())
    up = [parent]
    for _ in range(1, log):
        prev = up[-1]
        # Guard -1 explicitly: prev[-1] would index the last element.
        up.append([prev[p] if p != -1 else -1 for p in prev])

    def kth_ancestor(u, k):
        """Return the k-th ancestor of u (k < 2**log), or -1 if none."""
        i = 0
        while k and u != -1:
            if k & 1:
                u = up[i][u]
            k >>= 1
            i += 1
        return u

    def lca(u, v):
        """Return the lowest common ancestor of u and v."""
        if depth[u] < depth[v]:
            u, v = v, u
        u = kth_ancestor(u, depth[u] - depth[v])
        if u == v:
            return u
        for k in range(log - 1, -1, -1):
            # At equal depth both jumps are -1 together, so they compare equal.
            if up[k][u] != up[k][v]:
                u = up[k][u]
                v = up[k][v]
        return up[0][u]

    answers = []
    for s, t in queries:
        if s == t:
            # Every vertex is trivially equidistant from a vertex and itself.
            answers.append(n)
            continue
        w = lca(s, t)
        a = depth[s] - depth[w]
        b = depth[t] - depth[w]
        if (a + b) % 2:
            # Odd path length: no vertex sits exactly in the middle.
            answers.append(0)
            continue
        half = (a + b) // 2
        if a == b:
            # The midpoint is the LCA. Every vertex NOT in the branch toward
            # s nor the branch toward t qualifies -- including all vertices
            # above w, which is why this is n - ..., not size[w] - ....
            child_s = kth_ancestor(s, a - 1)
            child_t = kth_ancestor(t, b - 1)
            answers.append(n - size[child_s] - size[child_t])
        else:
            # The midpoint lies strictly below w on the deeper endpoint's
            # side; the answer is its subtree minus the branch toward that
            # endpoint. half >= 1 here because the path length is >= 2.
            far = s if a > b else t
            child = kth_ancestor(far, half - 1)
            mid = up[0][child]
            answers.append(size[mid] - size[child])
    return answers


def main():
    """Parse the problem input from the module-level tokens and print answers.

    Consumes ``data`` starting at the global cursor ``idx``: ``N Q``, then
    ``N - 1`` edges ``a b``, then ``Q`` queries ``s t``. Writes one answer per
    line to stdout and leaves ``idx`` just past the last token read.
    """
    global idx
    n, q = int(data[idx]), int(data[idx + 1])
    idx += 2

    edges = []
    for _ in range(n - 1):
        edges.append((int(data[idx]), int(data[idx + 1])))
        idx += 2

    queries = []
    for _ in range(q):
        queries.append((int(data[idx]), int(data[idx + 1])))
        idx += 2

    print('\n'.join(map(str, _solve(n, edges, queries))))

# Script entry point; the stray no-op expression statement `0` that followed
# the guard (page-scrape residue) has been dropped.
if __name__ == "__main__":
    main()