結果

問題 No.898 tri-βutree
ユーザー yuly3yuly3
提出日時 2020-07-20 19:05:18
言語 PyPy3 (7.3.15)
結果 AC
実行時間 920 ms / 4,000 ms
コード長 2,631 bytes
コンパイル時間 198 ms
コンパイル使用メモリ 82,560 KB
実行使用メモリ 132,992 KB
最終ジャッジ日時 2024-04-26 09:29:20
合計ジャッジ時間 16,581 ms
ジャッジサーバーID (参考情報): judge1 / judge2
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 297 ms 125,952 KB
testcase_01 AC 42 ms 54,144 KB
testcase_02 AC 53 ms 62,720 KB
testcase_03 AC 53 ms 61,824 KB
testcase_04 AC 53 ms 61,952 KB
testcase_05 AC 54 ms 62,208 KB
testcase_06 AC 54 ms 62,080 KB
testcase_07 AC 919 ms 132,480 KB
testcase_08 AC 887 ms 131,712 KB
testcase_09 AC 855 ms 132,480 KB
testcase_10 AC 846 ms 132,736 KB
testcase_11 AC 866 ms 132,864 KB
testcase_12 AC 918 ms 132,352 KB
testcase_13 AC 863 ms 132,480 KB
testcase_14 AC 920 ms 132,352 KB
testcase_15 AC 873 ms 132,736 KB
testcase_16 AC 878 ms 131,840 KB
testcase_17 AC 850 ms 132,992 KB
testcase_18 AC 864 ms 132,096 KB
testcase_19 AC 892 ms 132,480 KB
testcase_20 AC 893 ms 132,352 KB
testcase_21 AC 906 ms 132,736 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque

# Generous recursion ceiling kept from the original submission; the traversals
# in this file use explicit stacks, so this is a safety margin, not a need.
sys.setrecursionlimit(10**7)
# Buffered binary readline: each call returns one raw input line as bytes.
rl = sys.stdin.buffer.readline


class LowestCommonAncestor:
    """Lowest-common-ancestor queries on a rooted tree via binary lifting.

    ``parent[k][v]`` holds the 2**k-th ancestor of ``v``, or -1 when that
    ancestor would lie above the root.  ``depth[v]`` is the number of edges
    between ``v`` and the root.  Construction is O(n log n); each query is
    O(log n).
    """

    def __init__(self, tree, root):
        """Build the depth and ancestor tables.

        Args:
            tree: adjacency lists; ``tree[v]`` lists the neighbours of node
                ``v`` (nodes are ``0 .. len(tree) - 1``).  Expected to be a
                tree; nodes not reachable from ``root`` keep depth 0 and
                parent -1.
            root: index of the root node.
        """
        self.n = len(tree)
        self.depth = [0] * self.n
        # 2**log_size > n - 1 >= any depth gap, so log_size levels suffice.
        self.log_size = self.n.bit_length()
        self.parent = [[-1] * self.n for _ in range(self.log_size)]

        # Iterative DFS (deque used as a LIFO stack) so deep trees never
        # touch the recursion limit.
        stack = deque([(root, -1, 0)])
        while stack:
            v, par, dist = stack.pop()
            self.parent[0][v] = par
            self.depth[v] = dist
            for child in tree[v]:
                if child != par:
                    stack.append((child, v, dist + 1))

        for k in range(1, self.log_size):
            prev = self.parent[k - 1]
            cur = self.parent[k]
            for v in range(self.n):
                mid = prev[v]
                # Guard -1 explicitly: prev[-1] would silently read the entry
                # of node n-1 and store a bogus ancestor above the root.
                cur[v] = -1 if mid == -1 else prev[mid]

    def query(self, u, v):
        """Return the lowest common ancestor of nodes ``u`` and ``v``."""
        if self.depth[v] < self.depth[u]:
            u, v = v, u
        # Lift the deeper node v to u's depth, one set bit of the gap at a time.
        diff = self.depth[v] - self.depth[u]
        k = 0
        while diff:
            if diff & 1:
                v = self.parent[k][v]
            diff >>= 1
            k += 1
        if u == v:
            return u

        # Jump both nodes up while their ancestors still differ; afterwards
        # they are distinct children of the LCA.
        for k in reversed(range(self.log_size)):
            pu = self.parent[k][u]
            pv = self.parent[k][v]
            if pu != pv:
                u, v = pu, pv
        return self.parent[0][v]

    def get_dist(self, u, v):
        """Return the number of edges on the tree path between ``u`` and ``v``."""
        ancestor = self.query(u, v)
        return self.depth[u] - self.depth[ancestor] + self.depth[v] - self.depth[ancestor]


def solve():
    """Answer yukicoder No.898 queries read from stdin.

    Input: ``N``, then ``N - 1`` lines ``u v w`` giving the weighted edges of a
    tree on nodes ``0 .. N - 1``, then ``Q`` and ``Q`` lines ``x y z``.  For
    each query, prints the total edge weight of the smallest subtree that
    connects ``x``, ``y`` and ``z``, one answer per line.
    """
    N = int(rl())
    graph = [[] for _ in range(N)]
    for _ in range(N - 1):
        u, v, w = map(int, rl().split())
        graph[u].append((v, w))
        graph[v].append((u, w))

    # Weighted depth of every node from root 0.  Tracking the parent (instead
    # of relaxing against an INF sentinel) visits each node exactly once and
    # terminates even if some edge weight is zero or negative.
    costs = [0] * N
    stack = [(0, -1)]
    while stack:
        cur, par = stack.pop()
        for child, w in graph[cur]:
            if child != par:
                costs[child] = costs[cur] + w
                stack.append((child, cur))

    # The LCA structure only needs the unweighted neighbour lists.
    lca = LowestCommonAncestor([[nxt for nxt, _ in adj] for adj in graph], 0)
    query = lca.query

    def dist(a, b):
        # Weighted path length: d(a) + d(b) - 2 * d(lca(a, b)).
        return costs[a] + costs[b] - 2 * costs[query(a, b)]

    ans = []
    Q = int(rl())
    for _ in range(Q):
        x, y, z = map(int, rl().split())
        # The closed walk x -> y -> z -> x crosses every edge of the connecting
        # subtree exactly twice, so half the sum of pairwise distances is its weight.
        ans.append((dist(x, y) + dist(y, z) + dist(z, x)) // 2)
    print(*ans, sep='\n')


if __name__ == "__main__":
    # Run only when executed as a script, so importing this module does not
    # start reading stdin.
    solve()
0