結果

問題 No.2337 Equidistant
ユーザー lam6er
提出日時 2025-03-31 17:46:35
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 3,527 bytes
コンパイル時間 238 ms
コンパイル使用メモリ 82,964 KB
実行使用メモリ 268,192 KB
最終ジャッジ日時 2025-03-31 17:48:07
合計ジャッジ時間 38,866 ms
ジャッジサーバーID (参考情報) judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 21 TLE * 1 -- * 6
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
# Module-level setup: raise the interpreter's recursion ceiling and alias
# ``input`` to a reader that returns the whole of stdin in one call.
sys.setrecursionlimit(2 ** 25)
input = stdin.read

def _build_tables(n, adj, log):
    """Root the tree at vertex 1 and build binary-lifting tables.

    Args:
        n: number of vertices, labelled 1..n.
        adj: adjacency lists; ``adj[v]`` holds the neighbours of ``v``.
        log: number of lifting levels; ``2**log`` must exceed the tree depth.

    Returns:
        ``(up, depth, size)``:
        - ``up[k][v]`` is the ``2**k``-th ancestor of ``v``. It is 0 once the
          root is passed; index 0 is a sentinel that is its own ancestor.
        - ``depth[v]`` is the edge distance from vertex 1.
        - ``size[v]`` is the number of vertices in ``v``'s subtree.
    """
    parent = [0] * (n + 1)
    depth = [0] * (n + 1)
    seen = [False] * (n + 1)
    seen[1] = True
    order = []  # DFS pre-order: every vertex is appended after its parent
    stack = [1]
    while stack:
        u = stack.pop()
        order.append(u)
        child_depth = depth[u] + 1
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                depth[v] = child_depth
                stack.append(v)

    size = [1] * (n + 1)
    # Reverse pre-order folds every child into its parent before the parent
    # itself is folded upward.
    for u in reversed(order):
        p = parent[u]
        if p:
            size[p] += size[u]

    up = [parent]
    for _ in range(1, log):
        prev = up[-1]
        up.append([prev[prev[v]] for v in range(n + 1)])
    return up, depth, size


def _kth_ancestor(up, v, k):
    """Return the ancestor of ``v`` that is ``k`` edges above it (0 <= k < 2**len(up))."""
    level = 0
    while k:
        if k & 1:
            v = up[level][v]
        k >>= 1
        level += 1
    return v


def _lca(up, depth, u, v):
    """Return the lowest common ancestor of ``u`` and ``v``."""
    if depth[u] < depth[v]:
        u, v = v, u
    u = _kth_ancestor(up, u, depth[u] - depth[v])
    if u == v:
        return u
    for level in range(len(up) - 1, -1, -1):
        row = up[level]
        # Above the root both sides map to the sentinel 0, so they compare
        # equal and are never taken.
        if row[u] != row[v]:
            u = row[u]
            v = row[v]
    return up[0][u]


def _count_equidistant(n, up, depth, size, s, t):
    """Return how many vertices are at equal distance from ``s`` and ``t``.

    The count is 0 when dist(s, t) is odd, and ``n`` when ``s == t``.
    Otherwise let M be the midpoint of the s-t path. The answer is every
    vertex except the two branches of M that lead towards ``s`` and ``t``.
    Both branches are found directly by lifting, so the cost is O(log n)
    regardless of M's degree.
    """
    if depth[s] < depth[t]:
        s, t = t, s  # make s the deeper endpoint so M lies on s's side
    lca = _lca(up, depth, s, t)
    a = depth[s] - depth[lca]
    b = depth[t] - depth[lca]
    dist = a + b
    if dist & 1:
        return 0
    if dist == 0:
        return n
    half = dist // 2  # 1 <= half <= a, because a >= b

    # The neighbour of M towards s is a child of M.
    toward_s = _kth_ancestor(up, s, half - 1)
    excluded = size[toward_s]
    if a == b:
        # M is the LCA; t is reached through another child of M.
        excluded += size[_kth_ancestor(up, t, half - 1)]
    else:
        # M lies strictly below the LCA; t is reached through M's parent,
        # whose side holds everything outside M's subtree.
        mid = up[0][toward_s]
        excluded += n - size[mid]
    return n - excluded


def main():
    """Read a tree and queries from stdin and print one answer per query.

    Input (whitespace separated):
    - ``N Q``;
    - then ``N - 1`` edges ``a b``;
    - then ``Q`` pairs ``S T``.

    For each pair, the number of vertices equidistant from ``S`` and ``T``
    is printed on its own line.

    Design notes:
    - Each query is answered in O(log N). The previous version ran two LCA
      queries per neighbour of the midpoint, which degrades to O(N log N)
      per query on high-degree vertices.
    - ``sys.stdin`` is read directly rather than through the module-level
      ``input`` alias, so the function does not depend on that shadowing.
    """
    data = sys.stdin.read().split()
    n, q = int(data[0]), int(data[1])
    pos = 2
    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        a, b = int(data[pos]), int(data[pos + 1])
        pos += 2
        adj[a].append(b)
        adj[b].append(a)

    # n.bit_length() levels reach any ancestor distance up to n - 1, so the
    # table adapts to n instead of using a fixed level count.
    up, depth, size = _build_tables(n, adj, max(1, n.bit_length()))

    out = []
    for _ in range(q):
        s, t = int(data[pos]), int(data[pos + 1])
        pos += 2
        out.append(str(_count_equidistant(n, up, depth, size, s, t)))
    print("\n".join(out))

if __name__ == "__main__":
    # Run the solver only when executed as a script, never on import.
    main()
0