結果
問題 |
No.386 貪欲な領主
|
ユーザー |
![]() |
提出日時 | 2021-01-11 17:23:35 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 491 ms / 2,000 ms |
コード長 | 1,492 bytes |
コンパイル時間 | 166 ms |
コンパイル使用メモリ | 82,412 KB |
実行使用メモリ | 138,864 KB |
最終ジャッジ日時 | 2024-11-21 09:17:28 |
合計ジャッジ時間 | 4,082 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 16 |
ソースコード
import sys input = sys.stdin.buffer.readline sys.setrecursionlimit(10 ** 7) n = int(input()) edge = [[] for _ in range(n)] for _ in range(n - 1): x, y = map(int, input().split()) edge[x].append(y) edge[y].append(x) U = [int(input()) for _ in range(n)] M = int(input()) abc = tuple(tuple(map(int, input().split())) for _ in range(M)) D = n.bit_length() par = [[-1] * n for _ in range(D)] depth = [0] * n topo = [] que = [0] while que: s = que.pop() topo.append(s) for t in edge[s]: if t == par[0][s]: continue depth[t] = depth[s] + 1 par[0][t] = s que.append(t) for i in range(D-1): for j in range(n): par[i + 1][j] = par[i][par[i][j]] def lowest_ancestor(x, h): # xよりh上にあるノード番号を返す for i in reversed(range(D)): if h >= (1 << i): x = par[i][x] h -= (1 << i) return x def LCA(x, y): if depth[x] < depth[y]: x, y = y, x x = lowest_ancestor(x, depth[x] - depth[y]) if x == y: return x for i in reversed(range(D)): if par[i][x] != par[i][y]: x = par[i][x] y = par[i][y] return par[0][x] cnt = [0] * n for a, b, c in abc: x = LCA(a, b) cnt[a] += c cnt[b] += c cnt[x] -= c p = par[0][x] if p != -1: cnt[p] -= c ans = 0 for s in topo[::-1][:-1]: ans += U[s] * cnt[s] p = par[0][s] cnt[p] += cnt[s] ans += U[0] * cnt[0] print(ans)