結果
| 問題 |
No.898 tri-βutree
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2022-08-01 23:11:16 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 2,935 bytes |
| コンパイル時間 | 372 ms |
| コンパイル使用メモリ | 82,464 KB |
| 実行使用メモリ | 135,508 KB |
| 最終ジャッジ日時 | 2024-07-22 14:02:25 |
| 合計ジャッジ時間 | 20,902 ms |
|
ジャッジサーバーID (参考情報) |
judge1 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 1 WA * 20 |
ソースコード
#最小共通祖先
#ダブリング
class LcaDoubling:
"""
links[v] = { (u, w), (u, w), ... } (u:隣接頂点, w:辺の重み)
というグラフ情報から、ダブリングによるLCAを構築。
任意の2頂点のLCAおよび距離を取得できるようにする
"""
def __init__(self, n, links, root=0):
self.depths = [-1] * n
self.distances = [-1] * n
prev_ancestors = self._init_dfs(n, links, root)
self.ancestors = [prev_ancestors]
max_depth = max(self.depths)
d = 1
while d < max_depth:
next_ancestors = [prev_ancestors[p] for p in prev_ancestors]
self.ancestors.append(next_ancestors)
d <<= 1
prev_ancestors = next_ancestors
def _init_dfs(self, n, links, root):
q = [(root, -1, 0, 0)]
direct_ancestors = [-1] * (n + 1) # 頂点数より1個長くし、存在しないことを-1で表す。末尾(-1)要素は常に-1
while q:
v, p, dep, dist = q.pop()
direct_ancestors[v] = p
self.depths[v] = dep
self.distances[v] = dist
q.extend((u, v, dep + 1, dist + w) for u, w in links[v] if u != p)
return direct_ancestors
def get_lca(self, u, v):
du, dv = self.depths[u], self.depths[v]
if du > dv:
u, v = v, u
du, dv = dv, du
tu = u
tv = self.upstream(v, dv - du)
if u == tv:
return u
for k in range(du.bit_length() - 1, -1, -1):
mu = self.ancestors[k][tu]
mv = self.ancestors[k][tv]
if mu != mv:
tu = mu
tv = mv
lca = self.ancestors[0][tu]
assert lca == self.ancestors[0][tv]
return lca
def get_distance(self, u, v):
lca = self.get_lca(u, v)
return self.distances[u] + self.distances[v] - 2 * self.distances[lca]
def upstream(self, v, k):
i = 0
while k:
if k & 1:
v = self.ancestors[i][v]
k >>= 1
i += 1
return v
N = int(input())
links = [[] for i in range(N)]
for i in range(N-1):
u,v,w = map(int,input().split())
links[u].append((v,w))
links[v].append((u,w))
LCA = LcaDoubling(N, links)
Q = int(input())
ans = []
for i in range(Q):
x,y,z = map(int,input().split())
a = LCA.get_lca(x, y)
b = LCA.get_lca(y, z)
c = LCA.get_lca(z, x)
if a == b == c:
dist = LCA.get_distance(0,x)+LCA.get_distance(0,y)+LCA.get_distance(0,z)-2*LCA.get_distance(0,a)
else:
l = [a,b,c]
l.sort()
r2 = l[1]
for i in range(3):
if l[i] != r2:
r1 = l[i]
dist = LCA.get_distance(0,x)+LCA.get_distance(0,y)+LCA.get_distance(0,z)-LCA.get_distance(0,r2)-LCA.get_distance(0,r1)
ans.append(dist)
print(*ans,sep='\n')