結果

問題 No.1094 木登り / Climbing tree
ユーザー Theta
提出日時 2024-03-29 11:27:12
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,079 bytes
コンパイル時間 699 ms
コンパイル使用メモリ 82,048 KB
実行使用メモリ 241,796 KB
最終ジャッジ日時 2024-09-30 14:59:53
合計ジャッジ時間 51,669 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 11 TLE * 15
権限があれば一括ダウンロードができます

ソースコード

diff #

from math import inf, isinf
from heapq import heappop, heappush
from functools import cache
from collections import defaultdict, deque
from collections import deque
from itertools import count
from math import inf
import sys
from typing import List


def printe(*args, end="\n", **kwargs):
    """Print like the builtin print(), but to sys.stderr (handy for debugging)."""
    stream = sys.stderr  # looked up at call time so redirection is honoured
    print(*args, file=stream, end=end, **kwargs)


def doubling_on_tree(parent_list: List[int]) -> List[List[int]]:
    """Build a binary-lifting (doubling) table for a rooted tree.

    Args:
        parent_list: ``parent_list[v]`` is the parent of node ``v``; the root
            is a node that is its own parent. If several nodes are their own
            parent, the first one (lowest index) is taken as the root.

    Returns:
        A table ``t`` where ``t[v][k]`` is the ``2**k``-th ancestor of ``v``,
        saturating at the root. All rows have the same length; levels are
        added until the newest level maps every node to the root.

    Raises:
        ValueError: if no node is its own parent (including an empty list),
            or if some node never reaches the root (forest or cycle).
    """
    size = len(parent_list)
    root_idx = next(
        (idx for idx, p in enumerate(parent_list) if idx == p), None)
    if root_idx is None:
        raise ValueError("parent_list has no root (no node is its own parent)")

    double_list = [[p] for p in parent_list]
    unroot = set(range(size))
    # In a genuine tree every depth is <= size - 1 < 2 ** size.bit_length(),
    # so all nodes reach the root by this level. Bounding the loop turns the
    # former infinite loop on a forest/cycle into an explicit error.
    for k in range(1, size.bit_length() + 1):
        for idx, row in enumerate(double_list):
            row.append(double_list[row[k - 1]][k - 1])
            if row[-1] == root_idx:
                unroot.discard(idx)
        if not unroot:
            break
    else:
        raise ValueError(
            "parent_list is not a single tree: some nodes never reach the root")
    return double_list


def calc_lca(doubling: List[List[int]],
             depth: List[int],
             idx1: int,
             idx2: int) -> int:
    """Return the lowest common ancestor of two nodes via binary lifting.

    Args:
        doubling: table where ``doubling[v][k]`` is the 2**k-th ancestor of v.
        depth: integer depth of each node (root at depth 0).
        idx1, idx2: the two query nodes.
    """
    # Make ``deep`` the node that is not shallower than ``shallow``.
    if depth[idx1] >= depth[idx2]:
        deep, shallow = idx1, idx2
    else:
        deep, shallow = idx2, idx1

    # Lift ``deep`` to the depth of ``shallow`` one set bit at a time.
    gap = depth[deep] - depth[shallow]
    for level in range(gap.bit_length()):
        if (gap >> level) & 1:
            deep = doubling[deep][level]

    if deep == shallow:
        return deep

    # Climb both nodes while their ancestors differ, highest jump first.
    for level in reversed(range(len(doubling[deep]))):
        up_deep = doubling[deep][level]
        up_shallow = doubling[shallow][level]
        if up_deep != up_shallow:
            deep, shallow = up_deep, up_shallow

    parent = doubling[deep][0]
    assert parent == doubling[shallow][0]
    return parent


def dijkstra(graph: list[dict[int, int]], start: int,
             dist_max=int(1e9)) -> list[int]:
    """Single-source shortest paths on a graph with non-negative weights.

    Args:
        graph: adjacency list; ``graph[u][v]`` is the weight of edge u -> v.
        start: source node index.
        dist_max: value reported for nodes unreachable from ``start``.

    Returns:
        A list with the shortest distance from ``start`` to every node, or
        ``dist_max`` for unreachable nodes. Reachable nodes always get their
        true distance, even when it is >= ``dist_max`` (previously the
        sentinel doubled as the initial bound, so such nodes were never
        relaxed and were wrongly reported as ``dist_max``).
    """
    # ``None`` marks "not reached yet", independent of the caller's sentinel.
    best = [None] * len(graph)
    best[start] = 0
    heap = [(0, start)]
    while heap:
        d, node = heappop(heap)
        if d > best[node]:
            continue  # stale heap entry
        for nxt, weight in graph[node].items():
            cand = d + weight
            cur = best[nxt]
            if cur is None or cand < cur:
                best[nxt] = cand
                heappush(heap, (cand, nxt))
    return [dist_max if d is None else d for d in best]


def main():
    """Solve yukicoder No.1094: answer weighted-distance queries on a tree.

    Input (whitespace separated): N, then N-1 edges "a b c" (1-indexed nodes,
    weight c), then Q, then Q queries "s t". Prints dist(s, t) per query,
    computed as dist(root, s) + dist(root, t) - 2 * dist(root, lca(s, t)).
    """
    # Read everything at once and write once at the end: per-line input() and
    # print() are a likely cause of the TLE verdicts with large N and Q.
    tokens = map(int, sys.stdin.buffer.read().split())
    N = next(tokens)
    graph = [{} for _ in range(N)]
    for _ in range(N - 1):
        a, b, c = next(tokens) - 1, next(tokens) - 1, next(tokens)
        graph[a][b] = c
        graph[b][a] = c

    # BFS from node 0 yields parents, unweighted depths (for the LCA lifting)
    # and weighted root distances; on a tree the weighted distance can be
    # accumulated along the BFS, so no separate Dijkstra pass is needed.
    parents = [0] * N
    depth = [-1] * N  # -1 = not visited yet
    dists = [0] * N
    depth[0] = 0
    queue = deque((0,))
    while queue:
        c = queue.popleft()
        for n, w in graph[c].items():
            if depth[n] < 0:
                depth[n] = depth[c] + 1
                parents[n] = c
                dists[n] = dists[c] + w
                queue.append(n)

    doubles = doubling_on_tree(parents)

    out = []
    for _ in range(next(tokens)):
        s = next(tokens) - 1
        t = next(tokens) - 1
        lca = calc_lca(doubles, depth, s, t)
        out.append(dists[s] + dists[t] - 2 * dists[lca])
    if out:
        sys.stdout.write("\n".join(map(str, out)) + "\n")


# Script entry point (the stray no-op `0` statement left after it was removed).
if __name__ == "__main__":
    main()