結果
| 問題 |
No.898 tri-βutree
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2021-07-17 18:00:02 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 3,152 bytes |
| コンパイル時間 | 182 ms |
| コンパイル使用メモリ | 82,508 KB |
| 実行使用メモリ | 171,196 KB |
| 最終ジャッジ日時 | 2024-07-07 05:18:19 |
| 合計ジャッジ時間 | 24,156 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 10 WA * 11 |
ソースコード
import sys
input = lambda : sys.stdin.readline().rstrip()
sys.setrecursionlimit(2*10**5+10)
write = lambda x: sys.stdout.write(x+"\n")
debug = lambda x: sys.stderr.write(x+"\n")
writef = lambda x: print("{:.12f}".format(x))
# 木の読み込み・重みあり
n = int(input())
ns = [[] for _ in range(n)]
vs = [0]*n
for _ in range(n-1):
u,v,c = map(int, input().split())
u -= 1
v -= 1
ns[u].append((c,v))
ns[v].append((c,u))
"""木におけるダブリングdouble
祖先と何かを同時に求める
"""
# 深さ
def cdepth(ns, root=0):
# rootを根としたときの深さ
ps = [None] * n
ps[root] = -1
q = [0]
while q:
u = q.pop()
for c,v in ns[u]:
if ps[v] is None:
ps[v] = u
vs[v] = c
q.append(v)
# psを元から持っている場合、引数のnsをpsにしてこの下だけで良い
depth = [None] * len(ps)
ns = [[] for _ in range(len(ps))]
for i,p in enumerate(ps):
ns[p].append(i)
depth[root] = 0
q = [root]
while q:
u = q.pop()
for v in ns[u]:
if depth[v] is None:
depth[v] = depth[u] + 1
q.append(v)
return depth, ps
# ダブリング
def double(ps, vs=None):
# global: n=psのサイズ
prev = [[None]*n for _ in range(k)] # prev[i][j]: jから2^i個上の上司
vals = [[None]*n for _ in range(k)] # vals[i][j]: jから2^i個上の上司までの間の枝重みのmax
for j in range(n):
prev[0][j] = ps[j]
vals[0][j] = vs[j]
for i in range(1,k):
for j in range(n):
p = prev[i-1][j]
if p>=0:
prev[i][j] = prev[i-1][p]
vals[i][j] = op(vals[i-1][j], vals[i-1][p])
else:
prev[i][j] = p
vals[i][j] = vals[i-1][j]
return prev, vals
# k: 必要桁数を定める必要アリ
def cprev(u,i):
"""uからi個上の頂点を返す
"""
vv = ninf
for j in range(k):
if i>>j&1:
vv = op(vv, vals[j][u])
u = prev[j][u]
return u, vv
# k: 必要桁数を定める必要アリ
def lca(u,v):
if depth[u]<depth[v]:
v,val = cprev(v, depth[v]-depth[u])
else:
u,val = cprev(u, depth[u]-depth[v])
if u==v:
return u, val
# 上のvalをそのまま↓に持ち越すので注意
for i in range(k-1, -1, -1):
if prev[i][u]!=prev[i][v]:
val = op(vals[i][u], vals[i][v], val)
u = prev[i][u]
v = prev[i][v]
return prev[0][u], op(val, vals[0][u], vals[0][v]) # このあたり注意
depth, ps = cdepth(ns,0)
k = 0
n = len(ps)
while pow(2,k)<n:
k += 1
###
"""このあたりを書き換える
vs: 各ノードについての値
"""
def op(*args):
return sum(args)
ninf = 0
###
prev,vals = double(ps,vs)
q = int(input())
ans = []
for i in range(q):
x,y,z = map(lambda i: int(i)-1, input().split())
l, v1 = lca(x,y)
l2, v2 = lca(y,z)
l3, v3 = lca(x,z)
res = (v1+v2+v3)//2
ans.append(res)
write("\n".join(map(str, ans)))