結果

問題 No.386 貪欲な領主
ユーザー tktk_snsntktk_snsn
提出日時 2021-01-11 17:14:40
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 469 ms / 2,000 ms
コード長 1,498 bytes
コンパイル時間 128 ms
コンパイル使用メモリ 82,268 KB
実行使用メモリ 138,612 KB
最終ジャッジ日時 2024-05-01 06:52:18
合計ジャッジ時間 4,427 ms
ジャッジサーバーID
(参考情報)
judge5 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 32 ms
53,252 KB
testcase_01 AC 36 ms
52,460 KB
testcase_02 AC 38 ms
52,704 KB
testcase_03 AC 36 ms
53,164 KB
testcase_04 AC 434 ms
138,612 KB
testcase_05 AC 446 ms
132,504 KB
testcase_06 AC 469 ms
131,732 KB
testcase_07 AC 76 ms
76,992 KB
testcase_08 AC 134 ms
83,420 KB
testcase_09 AC 81 ms
76,920 KB
testcase_10 AC 32 ms
52,532 KB
testcase_11 AC 32 ms
53,084 KB
testcase_12 AC 68 ms
76,572 KB
testcase_13 AC 91 ms
78,452 KB
testcase_14 AC 452 ms
131,996 KB
testcase_15 AC 352 ms
131,464 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
input = sys.stdin.buffer.readline
sys.setrecursionlimit(10 ** 7)

n = int(input())
edge = [[] for _ in range(n)]
for _ in range(n - 1):
    x, y = map(int, input().split())
    edge[x].append(y)
    edge[y].append(x)
U = [int(input()) for _ in range(n)]
M = int(input())
abc = tuple(tuple(map(int, input().split())) for _ in range(M))

D = n.bit_length()
par = [[-1] * (n + 1) for _ in range(D)]
depth = [0] * n
topo = []
que = [0]
while que:
    s = que.pop()
    topo.append(s)
    for t in edge[s]:
        if t == par[0][s]:
            continue
        depth[t] = depth[s] + 1
        par[0][t] = s
        que.append(t)

for i in range(D-1):
    for j in range(n):
        par[i + 1][j] = par[i][par[i][j]]


def lowest_ancestor(x, h):  # xよりh上にあるノード番号を返す
    for i in reversed(range(D)):
        if h >= (1 << i):
            x = par[i][x]
            h -= (1 << i)
    return x


def LCA(x, y):
    if depth[x] < depth[y]:
        x, y = y, x
    x = lowest_ancestor(x, depth[x] - depth[y])
    if x == y:
        return x
    for i in reversed(range(D)):
        if par[i][x] != par[i][y]:
            x = par[i][x]
            y = par[i][y]
    return par[0][x]


cnt = [0] * n
for a, b, c in abc:
    x = LCA(a, b)
    cnt[a] += c
    cnt[b] += c
    cnt[x] -= c
    p = par[0][x]
    if p != -1:
        cnt[p] -= c

ans = 0
for s in topo[::-1][:-1]:
    ans += U[s] * cnt[s]
    p = par[0][s]
    cnt[p] += cnt[s]
ans += U[0] * cnt[0]

print(ans)
0