結果
| 問題 |
No.898 tri-βutree
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2023-08-04 01:14:34 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
AC
|
| 実行時間 | 1,053 ms / 4,000 ms |
| コード長 | 2,657 bytes |
| コンパイル時間 | 985 ms |
| コンパイル使用メモリ | 82,096 KB |
| 実行使用メモリ | 113,672 KB |
| 最終ジャッジ日時 | 2024-10-13 22:19:45 |
| 合計ジャッジ時間 | 20,716 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 21 |
ソースコード
import sys
input = sys.stdin.readline
class Lowest_Common_Ancestor:
def __init__(self, T, root=0):
self.dist, self.parent, self.weight = self._preprocess(T, root)
def _preprocess(self, T, root):
"""
前処理 O(N logN)
"""
from collections import deque
n = len(T)
k = 1
while (1 << k) < n:
k += 1
q = deque([root])
dist = [-1] * n
parent = [[-1] * n for _ in range(k + 1)]
weight = [0] * n
dist[root], parent[0][root], weight[root] = 0, root, 0
while q:
v = q.popleft()
for w, nv in T[v]:
if nv == parent[0][v]:
continue
dist[nv], parent[0][nv] = dist[v] + 1, v
weight[nv] = weight[v] + w
q.append(nv)
for i in range(k - 1):
for j in range(n):
parent[i + 1][j] = parent[i][parent[i][j]]
return dist, parent, weight
def query(self, u, v):
"""
u, v のLCAを取得
O(logN)
"""
if self.dist[u] < self.dist[v]:
u, v = v, u
k = len(self.parent)
for i in range(k):
if ((self.dist[u] - self.dist[v]) >> i) & 1:
u = self.parent[i][u]
if u == v:
return u
for i in range(k - 1, -1, -1):
if self.parent[i][u] != self.parent[i][v]:
u = self.parent[i][u]
v = self.parent[i][v]
return self.parent[0][u]
def get_2v_dist(self, u, v):
"""
u-v間の距離を取得
O(logN)
"""
return self.dist[u] + self.dist[v] - 2 * self.dist[self.query(u, v)]
def get_2v_weight(self, u, v):
"""
u-v間の重みを取得
O(logN)
"""
return self.weight[u] + self.weight[v] - 2 * self.weight[self.query(u, v)]
def is_on_path(self, u, v, a):
"""
path u-v 上に a が存在するかどうか判定
O(logN)
"""
return get_2v_dist(u, a) + get_2v_dist(v, a) == get_2v_dist(u, v)
n = int(input())
edge = [[] for _ in range(n)]
for _ in range(n - 1):
a, b, w = map(int, input().split())
edge[a].append((w, b))
edge[b].append((w, a))
lca = Lowest_Common_Ancestor(edge)
q = int(input())
for _ in range(q):
ans = 10**18
x, y, z = map(int, input().split())
for _ in range(3):
tmp = lca.get_2v_weight(x, y)
w = lca.query(x, y)
tmp += lca.get_2v_weight(z, w)
if tmp < ans:
ans = tmp
x, y, z = y, z, x
print(ans)