結果

問題 No.2337 Equidistant
ユーザー lam6er
提出日時 2025-04-15 21:46:09
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,163 bytes
コンパイル時間 454 ms
コンパイル使用メモリ 81,536 KB
実行使用メモリ 268,640 KB
最終ジャッジ日時 2025-04-15 21:47:31
合計ジャッジ時間 25,930 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 6 WA * 22
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque

sys.setrecursionlimit(1 << 25)

def main():
    input = sys.stdin.read().split()
    idx = 0
    N, Q = int(input[idx]), int(input[idx+1])
    idx += 2
    
    adj = [[] for _ in range(N+1)]
    for _ in range(N-1):
        a = int(input[idx])
        b = int(input[idx+1])
        adj[a].append(b)
        adj[b].append(a)
        idx += 2
    
    parent = [0] * (N + 1)
    depth = [0] * (N + 1)
    visited = [False] * (N + 1)
    root = 1
    parent[root] = 0
    q = deque([root])
    visited[root] = True
    
    while q:
        u = q.popleft()
        for v in adj[u]:
            if not visited[v] and v != parent[u]:
                parent[v] = u
                depth[v] = depth[u] + 1
                visited[v] = True
                q.append(v)
    
    size = [1] * (N + 1)
    stack = [(root, False)]
    while stack:
        node, processed = stack.pop()
        if processed:
            for v in adj[node]:
                if v != parent[node]:
                    size[node] += size[v]
        else:
            stack.append((node, True))
            for v in adj[node]:
                if v != parent[node]:
                    stack.append((v, False))
    
    LOG = 20
    up = [[0] * (N + 1) for _ in range(LOG)]
    up[0] = parent
    for k in range(1, LOG):
        for v in range(1, N + 1):
            up[k][v] = up[k-1][up[k-1][v]]
    
    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        for k in range(LOG-1, -1, -1):
            if depth[u] - (1 << k) >= depth[v]:
                u = up[k][u]
        if u == v:
            return u
        for k in range(LOG-1, -1, -1):
            if up[k][u] != up[k][v]:
                u = up[k][u]
                v = up[k][v]
        return parent[u]
    
    def get_kth_ancestor(u, k):
        current = u
        for j in range(LOG):
            if (k >> j) & 1:
                current = up[j][current]
                if current == 0:
                    break
        return current
    
    dir_size = [{} for _ in range(N+1)]
    for u in range(1, N+1):
        for v in adj[u]:
            if v == parent[u]:
                dir_size[u][v] = N - size[u]
            else:
                dir_size[u][v] = size[v]
    
    output = []
    for _ in range(Q):
        S = int(input[idx])
        T = int(input[idx+1])
        idx += 2
        
        L = lca(S, T)
        a = depth[S] - depth[L]
        b = depth[T] - depth[L]
        D = a + b
        
        if D % 2 != 0:
            output.append(0)
            continue
        
        k = D // 2
        if k <= a:
            M = get_kth_ancestor(S, k)
        else:
            M = get_kth_ancestor(T, (b - (k - a)))
        
        if k == 0:
            output.append(0)
            continue
        
        prev = get_kth_ancestor(S, k-1)
        next_node = get_kth_ancestor(T, k-1)
        
        ds_prev = dir_size[M].get(prev, 0)
        ds_next = dir_size[M].get(next_node, 0)
        
        ans = N - ds_prev - ds_next
        output.append(ans)
    
    print('\n'.join(map(str, output)))

if __name__ == '__main__':
    main()
0