結果

問題 No.2337 Equidistant
ユーザー lam6er
提出日時 2025-03-31 17:50:15
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,373 bytes
コンパイル時間 325 ms
コンパイル使用メモリ 82,084 KB
実行使用メモリ 476,996 KB
最終ジャッジ日時 2025-03-31 17:51:32
合計ジャッジ時間 22,716 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other WA * 21 TLE * 1 -- * 6
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
from collections import deque
sys.setrecursionlimit(1 << 25)

def main():
    input = sys.stdin.read().split()
    idx = 0
    N, Q = int(input[idx]), int(input[idx+1])
    idx += 2
    edges = [[] for _ in range(N+1)]
    for _ in range(N-1):
        a, b = int(input[idx]), int(input[idx+1])
        edges[a].append(b)
        edges[b].append(a)
        idx +=2
    queries = []
    for _ in range(Q):
        s, t = int(input[idx]), int(input[idx+1])
        queries.append((s, t))
        idx +=2

    LOG = 20
    parent = [[-1]*(N+1) for _ in range(LOG)]
    depth = [0]*(N+1)
    size = [1]*(N+1)
    visited = [False]*(N+1)
    q = deque([1])
    visited[1] = True
    parent[0][1] = -1
    adj = [[] for _ in range(N+1)]
    while q:
        u = q.popleft()
        for v in edges[u]:
            if not visited[v] and v != parent[0][u]:
                visited[v] = True
                parent[0][v] = u
                depth[v] = depth[u] +1
                q.append(v)
                adj[u].append(v)
    for k in range(1, LOG):
        for v in range(1, N+1):
            if parent[k-1][v] != -1:
                parent[k][v] = parent[k-1][parent[k-1][v]]

    def dfs(u, p):
        for v in adj[u]:
            if v != p:
                dfs(v, u)
                size[u] += size[v]
    dfs(1, -1)

    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        for k in reversed(range(LOG)):
            if depth[u] - (1 << k) >= depth[v]:
                u = parent[k][u]
        if u == v:
            return u
        for k in reversed(range(LOG)):
            if parent[k][u] != parent[k][v]:
                u = parent[k][u]
                v = parent[k][v]
        return parent[0][u]

    def get_kth_ancestor(u, k):
        if k <0:
            return -1
        for i in range(LOG):
            if k & (1 << i):
                u = parent[i][u]
                if u == -1:
                    return -1
        return u

    def get_dist(u, v):
        a = lca(u, v)
        return depth[u] + depth[v] - 2 * depth[a]

    out = []
    for s, t in queries:
        a_node = lca(s, t)
        a = depth[s] - depth[a_node]
        b = depth[t] - depth[a_node]
        D = a + b
        if D %2 !=0:
            out.append('0')
            continue
        k = D//2
        if a >=k:
            m = get_kth_ancestor(s, k)
        else:
            rem = k -a
            m = get_kth_ancestor(t, (b - rem))
        l = get_kth_ancestor(s, k-1) if k !=0 else -1
        l_tmp = get_kth_ancestor(t, (a +b -k) -1) if (a +b -k) !=0 else -1
        if k ==0:
            l = -1
        s_part = l
        if k ==0:
            s_part = -1
        if s_part ==-1:
            s_size =0
        else:
            if parent[0][m] == s_part:
                s_size = N - size[m]
            else:
                s_size = size[s_part] if s_part in adj[m] else 0
        t_part = l_tmp
        k_t = (a + b -k)
        if k_t ==0:
            t_part = -1
        if t_part == -1:
            t_size =0
        else:
            if parent[0][m] == t_part:
                t_size = N - size[m]
            else:
                t_size = size[t_part] if t_part in adj[m] else 0
        ans = size[m] - s_size - t_size
        out.append(str(ans))
    print('\n'.join(out))

if __name__ == '__main__':
    main()
0