結果

問題 No.898 tri-βutree
ユーザー dice4084dice4084
提出日時 2023-08-04 01:14:34
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,053 ms / 4,000 ms
コード長 2,657 bytes
コンパイル時間 985 ms
コンパイル使用メモリ 82,096 KB
実行使用メモリ 113,672 KB
最終ジャッジ日時 2024-10-13 22:19:45
合計ジャッジ時間 20,716 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 401 ms
111,424 KB
testcase_01 AC 44 ms
54,144 KB
testcase_02 AC 62 ms
65,280 KB
testcase_03 AC 63 ms
65,536 KB
testcase_04 AC 61 ms
64,896 KB
testcase_05 AC 61 ms
65,024 KB
testcase_06 AC 59 ms
65,152 KB
testcase_07 AC 1,019 ms
113,584 KB
testcase_08 AC 993 ms
112,556 KB
testcase_09 AC 977 ms
113,040 KB
testcase_10 AC 999 ms
112,684 KB
testcase_11 AC 1,030 ms
113,672 KB
testcase_12 AC 1,053 ms
113,028 KB
testcase_13 AC 1,031 ms
112,868 KB
testcase_14 AC 1,037 ms
112,480 KB
testcase_15 AC 1,032 ms
112,412 KB
testcase_16 AC 1,022 ms
112,496 KB
testcase_17 AC 1,047 ms
113,024 KB
testcase_18 AC 1,027 ms
113,444 KB
testcase_19 AC 1,019 ms
112,476 KB
testcase_20 AC 993 ms
112,644 KB
testcase_21 AC 1,022 ms
112,660 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

input = sys.stdin.readline


class Lowest_Common_Ancestor:
    def __init__(self, T, root=0):
        self.dist, self.parent, self.weight = self._preprocess(T, root)

    def _preprocess(self, T, root):
        """
        前処理 O(N logN)
        """
        from collections import deque

        n = len(T)
        k = 1
        while (1 << k) < n:
            k += 1

        q = deque([root])
        dist = [-1] * n
        parent = [[-1] * n for _ in range(k + 1)]
        weight = [0] * n
        dist[root], parent[0][root], weight[root] = 0, root, 0
        while q:
            v = q.popleft()
            for w, nv in T[v]:
                if nv == parent[0][v]:
                    continue
                dist[nv], parent[0][nv] = dist[v] + 1, v
                weight[nv] = weight[v] + w
                q.append(nv)

        for i in range(k - 1):
            for j in range(n):
                parent[i + 1][j] = parent[i][parent[i][j]]
        return dist, parent, weight

    def query(self, u, v):
        """
        u, v のLCAを取得
        O(logN)
        """
        if self.dist[u] < self.dist[v]:
            u, v = v, u
        k = len(self.parent)
        for i in range(k):
            if ((self.dist[u] - self.dist[v]) >> i) & 1:
                u = self.parent[i][u]

        if u == v:
            return u
        for i in range(k - 1, -1, -1):
            if self.parent[i][u] != self.parent[i][v]:
                u = self.parent[i][u]
                v = self.parent[i][v]
        return self.parent[0][u]

    def get_2v_dist(self, u, v):
        """
        u-v間の距離を取得
        O(logN)
        """
        return self.dist[u] + self.dist[v] - 2 * self.dist[self.query(u, v)]

    def get_2v_weight(self, u, v):
        """
        u-v間の重みを取得
        O(logN)
        """
        return self.weight[u] + self.weight[v] - 2 * self.weight[self.query(u, v)]

    def is_on_path(self, u, v, a):
        """
        path u-v 上に a が存在するかどうか判定
        O(logN)
        """
        return get_2v_dist(u, a) + get_2v_dist(v, a) == get_2v_dist(u, v)


n = int(input())
edge = [[] for _ in range(n)]
for _ in range(n - 1):
    a, b, w = map(int, input().split())
    edge[a].append((w, b))
    edge[b].append((w, a))

lca = Lowest_Common_Ancestor(edge)
q = int(input())
for _ in range(q):
    ans = 10**18
    x, y, z = map(int, input().split())
    for _ in range(3):
        tmp = lca.get_2v_weight(x, y)
        w = lca.query(x, y)
        tmp += lca.get_2v_weight(z, w)
        if tmp < ans:
            ans = tmp
        x, y, z = y, z, x

    print(ans)
0