結果

問題 No.898 tri-βutree
ユーザー dice4084dice4084
提出日時 2023-08-04 01:14:34
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,213 ms / 4,000 ms
コード長 2,657 bytes
コンパイル時間 164 ms
コンパイル使用メモリ 82,532 KB
実行使用メモリ 113,680 KB
最終ジャッジ日時 2024-04-21 23:49:48
合計ジャッジ時間 22,744 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 442 ms
111,676 KB
testcase_01 AC 41 ms
55,416 KB
testcase_02 AC 56 ms
65,972 KB
testcase_03 AC 56 ms
66,164 KB
testcase_04 AC 59 ms
65,284 KB
testcase_05 AC 58 ms
65,348 KB
testcase_06 AC 56 ms
66,080 KB
testcase_07 AC 1,213 ms
113,680 KB
testcase_08 AC 1,149 ms
112,812 KB
testcase_09 AC 1,157 ms
113,172 KB
testcase_10 AC 1,173 ms
112,828 KB
testcase_11 AC 1,209 ms
113,676 KB
testcase_12 AC 1,205 ms
113,024 KB
testcase_13 AC 1,188 ms
112,860 KB
testcase_14 AC 1,204 ms
112,864 KB
testcase_15 AC 1,149 ms
112,796 KB
testcase_16 AC 1,191 ms
112,824 KB
testcase_17 AC 1,193 ms
113,092 KB
testcase_18 AC 1,163 ms
113,188 KB
testcase_19 AC 1,162 ms
112,380 KB
testcase_20 AC 1,169 ms
113,020 KB
testcase_21 AC 1,175 ms
112,808 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

input = sys.stdin.readline


class Lowest_Common_Ancestor:
    def __init__(self, T, root=0):
        self.dist, self.parent, self.weight = self._preprocess(T, root)

    def _preprocess(self, T, root):
        """
        前処理 O(N logN)
        """
        from collections import deque

        n = len(T)
        k = 1
        while (1 << k) < n:
            k += 1

        q = deque([root])
        dist = [-1] * n
        parent = [[-1] * n for _ in range(k + 1)]
        weight = [0] * n
        dist[root], parent[0][root], weight[root] = 0, root, 0
        while q:
            v = q.popleft()
            for w, nv in T[v]:
                if nv == parent[0][v]:
                    continue
                dist[nv], parent[0][nv] = dist[v] + 1, v
                weight[nv] = weight[v] + w
                q.append(nv)

        for i in range(k - 1):
            for j in range(n):
                parent[i + 1][j] = parent[i][parent[i][j]]
        return dist, parent, weight

    def query(self, u, v):
        """
        u, v のLCAを取得
        O(logN)
        """
        if self.dist[u] < self.dist[v]:
            u, v = v, u
        k = len(self.parent)
        for i in range(k):
            if ((self.dist[u] - self.dist[v]) >> i) & 1:
                u = self.parent[i][u]

        if u == v:
            return u
        for i in range(k - 1, -1, -1):
            if self.parent[i][u] != self.parent[i][v]:
                u = self.parent[i][u]
                v = self.parent[i][v]
        return self.parent[0][u]

    def get_2v_dist(self, u, v):
        """
        u-v間の距離を取得
        O(logN)
        """
        return self.dist[u] + self.dist[v] - 2 * self.dist[self.query(u, v)]

    def get_2v_weight(self, u, v):
        """
        u-v間の重みを取得
        O(logN)
        """
        return self.weight[u] + self.weight[v] - 2 * self.weight[self.query(u, v)]

    def is_on_path(self, u, v, a):
        """
        path u-v 上に a が存在するかどうか判定
        O(logN)
        """
        return get_2v_dist(u, a) + get_2v_dist(v, a) == get_2v_dist(u, v)


n = int(input())
edge = [[] for _ in range(n)]
for _ in range(n - 1):
    a, b, w = map(int, input().split())
    edge[a].append((w, b))
    edge[b].append((w, a))

lca = Lowest_Common_Ancestor(edge)
q = int(input())
for _ in range(q):
    ans = 10**18
    x, y, z = map(int, input().split())
    for _ in range(3):
        tmp = lca.get_2v_weight(x, y)
        w = lca.query(x, y)
        tmp += lca.get_2v_weight(z, w)
        if tmp < ans:
            ans = tmp
        x, y, z = y, z, x

    print(ans)
0