結果

問題 No.898 tri-βutree
ユーザー shotoyooshotoyoo
提出日時 2021-07-17 17:57:32
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,401 bytes
コンパイル時間 255 ms
コンパイル使用メモリ 82,304 KB
実行使用メモリ 180,612 KB
最終ジャッジ日時 2024-07-07 05:15:21
合計ジャッジ時間 32,203 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 954 ms
172,928 KB
testcase_01 AC 37 ms
52,876 KB
testcase_02 WA -
testcase_03 RE -
testcase_04 AC 52 ms
65,328 KB
testcase_05 AC 50 ms
62,868 KB
testcase_06 WA -
testcase_07 RE -
testcase_08 AC 1,568 ms
153,612 KB
testcase_09 AC 2,668 ms
180,300 KB
testcase_10 RE -
testcase_11 RE -
testcase_12 AC 2,737 ms
176,384 KB
testcase_13 AC 1,604 ms
152,648 KB
testcase_14 AC 2,779 ms
180,612 KB
testcase_15 AC 2,658 ms
174,668 KB
testcase_16 RE -
testcase_17 RE -
testcase_18 AC 2,736 ms
176,116 KB
testcase_19 WA -
testcase_20 RE -
testcase_21 RE -
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
input = lambda : sys.stdin.readline().rstrip()

sys.setrecursionlimit(2*10**5+10)
write = lambda x: sys.stdout.write(x+"\n")
debug = lambda x: sys.stderr.write(x+"\n")
writef = lambda x: print("{:.12f}".format(x))

# 木の読み込み・重みあり
n = int(input())
ns = [[] for _ in range(n)]
vs = [0]*n
for _ in range(n-1):
    u,v,c = map(int, input().split())
    u -= 1
    v -= 1
    ns[u].append((c,v))
    ns[v].append((c,u))
    
"""木におけるダブリングdouble
祖先と何かを同時に求める
"""
# 深さ
def cdepth(ns, root=0):
    # rootを根としたときの深さ
    ps = [None] * n
    ps[root] = -1
    q = [0]
    while q:
        u = q.pop()
        for c,v in ns[u]:
            if ps[v] is None:
                ps[v] = u
                vs[v] = c
                q.append(v)
    # psを元から持っている場合、引数のnsをpsにしてこの下だけで良い
    depth = [None] * len(ps)
    ns = [[] for _ in range(len(ps))]
    for i,p in enumerate(ps):
        ns[p].append(i)
    depth[root] = 0
    q = [root]
    while q:
        u = q.pop()
        for v in ns[u]:
            if depth[v] is None:
                depth[v] = depth[u] + 1
                q.append(v)
    return depth, ps

# ダブリング
def double(ps, vs=None):
    # global: n=psのサイズ
    prev = [[None]*n for _ in range(k)] # prev[i][j]: jから2^i個上の上司
    vals = [[None]*n for _ in range(k)] # vals[i][j]: jから2^i個上の上司までの間の枝重みのmax
    for j in range(n):
        prev[0][j] = ps[j]
        vals[0][j] = vs[j]
    for i in range(1,k):
        for j in range(n):
            p = prev[i-1][j]
            if p>=0:
                prev[i][j] = prev[i-1][p]
                vals[i][j] = op(vals[i-1][j], vals[i-1][p])
            else:
                prev[i][j] = p
                vals[i][j] = vals[i-1][j]
    return prev, vals

# k: 必要桁数を定める必要アリ
def cprev(u,i):
    """uからi個上の頂点を返す
    """
    vv = ninf
    for j in range(k):
        if i>>j&1:
            vv = op(vv, vals[j][u])
            u = prev[j][u]
    return u, vv

# k: 必要桁数を定める必要アリ
def lca(u,v):
    if depth[u]<depth[v]:
        v,val = cprev(v, depth[v]-depth[u])
    else:
        u,val = cprev(u, depth[u]-depth[v])
    if u==v:
        return u, val
    # 上のvalをそのまま↓に持ち越すので注意
    for i in range(k-1, -1, -1):
        if prev[i][u]!=prev[i][v]:
            val = op(vals[i][u], vals[i][v], val)
            u = prev[i][u]
            v = prev[i][v]
    return prev[0][u], op(val, vals[0][u], vals[0][v]) # このあたり注意

depth, ps = cdepth(ns,0)
k = 0
n = len(ps)
while pow(2,k)<n:
    k += 1
###
"""このあたりを書き換える
vs: 各ノードについての値
"""
def op(*args):
    return sum(args)
ninf = 0
###

prev,vals = double(ps,vs)
q = int(input())
ans = []
for i in range(q):
    x,y,z = map(lambda i: int(i)-1, input().split())
    l, v1 = lca(x,y)
    l2, v2 = lca(l,z)
    if l!=l2:
        res = v1+v2
    else:
        l3, v3 = lca(x,z)
        if l3!=l:
            l4, v4 = lca(l3, y)
            assert l4==l
            res = v3 + v4
        else:
            l4,v4 = lca(y,z)
            l5, v5 = lca(l4, x)
            assert l==l5
            res = v4+v5
    ans.append(res)
write("\n".join(map(str, ans)))
0