結果

問題 No.386 貪欲な領主
ユーザー tktk_snsntktk_snsn
提出日時 2021-01-11 17:23:35
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 491 ms / 2,000 ms
コード長 1,492 bytes
コンパイル時間 166 ms
コンパイル使用メモリ 82,412 KB
実行使用メモリ 138,864 KB
最終ジャッジ日時 2024-11-21 09:17:28
合計ジャッジ時間 4,082 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 35 ms
53,272 KB
testcase_01 AC 35 ms
53,088 KB
testcase_02 AC 39 ms
53,196 KB
testcase_03 AC 37 ms
53,048 KB
testcase_04 AC 450 ms
138,864 KB
testcase_05 AC 456 ms
132,348 KB
testcase_06 AC 491 ms
131,760 KB
testcase_07 AC 82 ms
77,048 KB
testcase_08 AC 142 ms
83,620 KB
testcase_09 AC 89 ms
77,652 KB
testcase_10 AC 37 ms
52,948 KB
testcase_11 AC 36 ms
53,288 KB
testcase_12 AC 74 ms
76,652 KB
testcase_13 AC 100 ms
78,380 KB
testcase_14 AC 473 ms
131,948 KB
testcase_15 AC 381 ms
131,380 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
input = sys.stdin.buffer.readline
sys.setrecursionlimit(10 ** 7)

n = int(input())
edge = [[] for _ in range(n)]
for _ in range(n - 1):
    x, y = map(int, input().split())
    edge[x].append(y)
    edge[y].append(x)
U = [int(input()) for _ in range(n)]
M = int(input())
abc = tuple(tuple(map(int, input().split())) for _ in range(M))

D = n.bit_length()
par = [[-1] * n for _ in range(D)]
depth = [0] * n
topo = []
que = [0]
while que:
    s = que.pop()
    topo.append(s)
    for t in edge[s]:
        if t == par[0][s]:
            continue
        depth[t] = depth[s] + 1
        par[0][t] = s
        que.append(t)

for i in range(D-1):
    for j in range(n):
        par[i + 1][j] = par[i][par[i][j]]


def lowest_ancestor(x, h):  # xよりh上にあるノード番号を返す
    for i in reversed(range(D)):
        if h >= (1 << i):
            x = par[i][x]
            h -= (1 << i)
    return x


def LCA(x, y):
    if depth[x] < depth[y]:
        x, y = y, x
    x = lowest_ancestor(x, depth[x] - depth[y])
    if x == y:
        return x
    for i in reversed(range(D)):
        if par[i][x] != par[i][y]:
            x = par[i][x]
            y = par[i][y]
    return par[0][x]


cnt = [0] * n
for a, b, c in abc:
    x = LCA(a, b)
    cnt[a] += c
    cnt[b] += c
    cnt[x] -= c
    p = par[0][x]
    if p != -1:
        cnt[p] -= c

ans = 0
for s in topo[::-1][:-1]:
    ans += U[s] * cnt[s]
    p = par[0][s]
    cnt[p] += cnt[s]
ans += U[0] * cnt[0]

print(ans)
0