結果

問題 No.898 tri-βutree
ユーザー shotoyooshotoyoo
提出日時 2021-07-17 18:00:02
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,152 bytes
コンパイル時間 687 ms
コンパイル使用メモリ 86,668 KB
実行使用メモリ 172,212 KB
最終ジャッジ日時 2023-09-21 11:19:44
合計ジャッジ時間 34,762 ms
ジャッジサーバーID
(参考情報)
judge14 / judge11
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1,021 ms
172,212 KB
testcase_01 AC 75 ms
71,248 KB
testcase_02 WA -
testcase_03 WA -
testcase_04 AC 92 ms
75,840 KB
testcase_05 AC 89 ms
75,652 KB
testcase_06 WA -
testcase_07 WA -
testcase_08 AC 1,911 ms
164,732 KB
testcase_09 AC 1,931 ms
163,772 KB
testcase_10 WA -
testcase_11 WA -
testcase_12 AC 1,903 ms
163,720 KB
testcase_13 AC 1,920 ms
162,424 KB
testcase_14 AC 1,968 ms
161,012 KB
testcase_15 AC 1,905 ms
160,928 KB
testcase_16 WA -
testcase_17 WA -
testcase_18 AC 1,875 ms
161,400 KB
testcase_19 WA -
testcase_20 WA -
testcase_21 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
input = lambda : sys.stdin.readline().rstrip()

sys.setrecursionlimit(2*10**5+10)
write = lambda x: sys.stdout.write(x+"\n")
debug = lambda x: sys.stderr.write(x+"\n")
writef = lambda x: print("{:.12f}".format(x))

# 木の読み込み・重みあり
n = int(input())
ns = [[] for _ in range(n)]
vs = [0]*n
for _ in range(n-1):
    u,v,c = map(int, input().split())
    u -= 1
    v -= 1
    ns[u].append((c,v))
    ns[v].append((c,u))
    
"""木におけるダブリングdouble
祖先と何かを同時に求める
"""
# 深さ
def cdepth(ns, root=0):
    # rootを根としたときの深さ
    ps = [None] * n
    ps[root] = -1
    q = [0]
    while q:
        u = q.pop()
        for c,v in ns[u]:
            if ps[v] is None:
                ps[v] = u
                vs[v] = c
                q.append(v)
    # psを元から持っている場合、引数のnsをpsにしてこの下だけで良い
    depth = [None] * len(ps)
    ns = [[] for _ in range(len(ps))]
    for i,p in enumerate(ps):
        ns[p].append(i)
    depth[root] = 0
    q = [root]
    while q:
        u = q.pop()
        for v in ns[u]:
            if depth[v] is None:
                depth[v] = depth[u] + 1
                q.append(v)
    return depth, ps

# ダブリング
def double(ps, vs=None):
    # global: n=psのサイズ
    prev = [[None]*n for _ in range(k)] # prev[i][j]: jから2^i個上の上司
    vals = [[None]*n for _ in range(k)] # vals[i][j]: jから2^i個上の上司までの間の枝重みのmax
    for j in range(n):
        prev[0][j] = ps[j]
        vals[0][j] = vs[j]
    for i in range(1,k):
        for j in range(n):
            p = prev[i-1][j]
            if p>=0:
                prev[i][j] = prev[i-1][p]
                vals[i][j] = op(vals[i-1][j], vals[i-1][p])
            else:
                prev[i][j] = p
                vals[i][j] = vals[i-1][j]
    return prev, vals

# k: 必要桁数を定める必要アリ
def cprev(u,i):
    """uからi個上の頂点を返す
    """
    vv = ninf
    for j in range(k):
        if i>>j&1:
            vv = op(vv, vals[j][u])
            u = prev[j][u]
    return u, vv

# k: 必要桁数を定める必要アリ
def lca(u,v):
    if depth[u]<depth[v]:
        v,val = cprev(v, depth[v]-depth[u])
    else:
        u,val = cprev(u, depth[u]-depth[v])
    if u==v:
        return u, val
    # 上のvalをそのまま↓に持ち越すので注意
    for i in range(k-1, -1, -1):
        if prev[i][u]!=prev[i][v]:
            val = op(vals[i][u], vals[i][v], val)
            u = prev[i][u]
            v = prev[i][v]
    return prev[0][u], op(val, vals[0][u], vals[0][v]) # このあたり注意

depth, ps = cdepth(ns,0)
k = 0
n = len(ps)
while pow(2,k)<n:
    k += 1
###
"""このあたりを書き換える
vs: 各ノードについての値
"""
def op(*args):
    return sum(args)
ninf = 0
###

prev,vals = double(ps,vs)
q = int(input())
ans = []
for i in range(q):
    x,y,z = map(lambda i: int(i)-1, input().split())
    l, v1 = lca(x,y)
    l2, v2 = lca(y,z)
    l3, v3 = lca(x,z)
    res = (v1+v2+v3)//2
    ans.append(res)
write("\n".join(map(str, ans)))
0