結果
問題 | No.386 貪欲な領主 |
ユーザー | rpy3cpp |
提出日時 | 2016-07-01 23:32:03 |
言語 | PyPy2 (7.3.15) |
結果 | AC |
実行時間 | 1,373 ms / 2,000 ms |
コード長 | 3,109 bytes |
コンパイル時間 | 370 ms |
コンパイル使用メモリ | 77,056 KB |
実行使用メモリ | 356,456 KB |
最終ジャッジ日時 | 2024-10-12 19:08:52 |
合計ジャッジ時間 | 9,947 ms |
ジャッジサーバーID (参考情報) | judge2 / judge3 |
(要ログイン)
テストケース
テストケース | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 82 ms | 75,648 KB |
testcase_01 | AC | 84 ms | 75,648 KB |
testcase_02 | AC | 83 ms | 75,520 KB |
testcase_03 | AC | 83 ms | 75,264 KB |
testcase_04 | AC | 1,373 ms | 356,456 KB |
testcase_05 | AC | 1,241 ms | 187,464 KB |
testcase_06 | AC | 1,345 ms | 189,980 KB |
testcase_07 | AC | 119 ms | 79,448 KB |
testcase_08 | AC | 307 ms | 99,584 KB |
testcase_09 | AC | 141 ms | 81,408 KB |
testcase_10 | AC | 82 ms | 75,520 KB |
testcase_11 | AC | 81 ms | 75,392 KB |
testcase_12 | AC | 117 ms | 79,360 KB |
testcase_13 | AC | 210 ms | 83,584 KB |
testcase_14 | AC | 858 ms | 175,552 KB |
testcase_15 | AC | 1,189 ms | 320,368 KB |
ソースコード
import sys

# yukicoder No.386: for every move (a, b, c), add c times the sum of the
# vertex costs on the tree path a-b, then print the grand total.

try:
    # Python 2 (the original judge ran PyPy2): use the lazy range.
    range = xrange  # noqa: F821 -- this name only exists on Python 2
except NameError:
    # Python 3: the builtin range is already lazy.
    pass


def read_data(stream=None):
    """Read one test case and return ``(N, Es, Us, M, moves)``.

    Expected layout (whitespace-separated integers): ``N``; ``N - 1`` edges
    ``a b``; ``N`` vertex costs; ``M``; ``M`` moves ``a b c``.  All tokens
    are read at once, so how the values are split across lines is irrelevant.

    Args:
        stream: file-like object to read from; defaults to ``sys.stdin``.

    Returns:
        N: number of vertices.
        Es: adjacency list; ``Es[v]`` lists the neighbours of ``v``.
        Us: ``Us[v]`` is the cost attached to vertex ``v``.
        M: number of moves.
        moves: list of ``[a, b, c]`` lists.
    """
    if stream is None:
        stream = sys.stdin
    values = iter(map(int, stream.read().split()))
    N = next(values)
    Es = [[] for _ in range(N)]
    for _ in range(N - 1):
        a = next(values)
        b = next(values)
        Es[a].append(b)
        Es[b].append(a)
    Us = [next(values) for _ in range(N)]
    M = next(values)
    # List-display items are evaluated left to right, so this reads a, b, c.
    moves = [[next(values), next(values), next(values)] for _ in range(M)]
    return N, Es, Us, M, moves


class DisjointSet(object):
    """Union-find over ``0..n-1`` with union by rank and path compression."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``."""
        self._link(self.find_set(x), self.find_set(y))

    def _link(self, x, y):
        """Hang one root under the other, keeping the lower-rank tree below."""
        if x == y:
            # Already the same set; linking would only inflate the rank.
            return
        if self.rank[x] > self.rank[y]:
            self.parent[y] = x
        else:
            self.parent[x] = y
            if self.rank[x] == self.rank[y]:
                self.rank[y] += 1

    def find_set(self, x):
        """Return the representative of ``x``'s set, compressing its path.

        Iterative, so no recursion depth is consumed.
        """
        parent = self.parent
        root = x
        while parent[root] != root:
            root = parent[root]
        # Second pass: point every node on the path directly at the root.
        while x != root:
            nxt = parent[x]
            parent[x] = root
            x = nxt
        return root


class Tree(object):
    """Tree rooted at ``root`` plus root-to-vertex cost prefix sums.

    Attributes:
        n: number of vertices.
        root: the root vertex.
        child: ``child[v]`` lists the children of ``v``.
        cum_cost: ``cum_cost[v]`` is the sum of ``Us`` over the path from
            ``root`` to ``v``, both ends included.
    """

    def __init__(self, N, Es, root, Us):
        self.n = N
        self.root = root
        self.child = [[] for _ in range(N)]
        self.cum_cost = [0] * N
        self._set_child(Es, Us)

    def _set_child(self, Es, Us):
        """Orient ``Es`` away from the root with an iterative DFS.

        Assumes ``Es`` describes a tree (connected, acyclic).  Vertices not
        reachable from the root get no children and keep ``cum_cost == 0``.
        """
        child = self.child
        cum_cost = self.cum_cost
        visited = [False] * self.n
        cum_cost[self.root] = Us[self.root]
        stack = [self.root]
        while stack:
            v = stack.pop()
            # Mark before scanning so neither the parent nor a self-loop
            # is treated as a child.
            visited[v] = True
            base = cum_cost[v]
            for u in Es[v]:
                if visited[u]:
                    continue
                child[v].append(u)
                cum_cost[u] = base + Us[u]
                stack.append(u)


class LCATarjan(object):
    """Offline lowest-common-ancestor queries (Tarjan's algorithm).

    The tree is walked once; each finished subtree is merged into its
    parent's union-find set, and a query ``(u, v)`` is answered when the
    later of its two endpoints finishes.  Traversal state is not reset, so
    ``find`` is meant to be called once per instance.
    """

    def __init__(self, tree):
        self.n = tree.n
        self.root = tree.root
        self.child = tree.child
        # ancestor[rep]: the current answer vertex for the set whose
        # representative is rep.
        self.ancestor = list(range(self.n))
        self.visited = [False] * self.n
        self.ds = DisjointSet(self.n)
        self.lca = dict()

    def find(self, pairs):
        """Return a dict mapping ``(u, v)`` and ``(v, u)`` to their LCA
        for every ``(u, v)`` in ``pairs``."""
        self._preprocess(pairs)
        self._LCA(self.root)
        return self.lca

    def _preprocess(self, pairs):
        """Index the queries by each of their endpoints."""
        self.pairs = [[] for _ in range(self.n)]
        for u, v in pairs:
            self.pairs[u].append(v)
            self.pairs[v].append(u)

    def _LCA(self, u):
        """Run Tarjan's DFS from ``u`` with an explicit stack.

        A recursive DFS would need one interpreter frame per tree level.
        ``sys.setrecursionlimit`` only lifts the interpreter's limit, not
        the C stack, so deep (path-like) trees could crash CPython.

        Each stack entry is ``(vertex, iterator over its remaining
        children)``; a vertex is finished when that iterator is exhausted.
        """
        child = self.child
        ancestor = self.ancestor
        visited = self.visited
        pairs = self.pairs
        lca = self.lca
        find_set = self.ds.find_set
        union = self.ds.union

        ancestor[find_set(u)] = u
        stack = [(u, iter(child[u]))]
        while stack:
            v, remaining = stack[-1]
            w = next(remaining, None)
            if w is not None:
                # Enter the next child of v.
                ancestor[find_set(w)] = w
                stack.append((w, iter(child[w])))
                continue
            # Every child of v is done: answer v's queries, then merge v
            # into its parent's set.
            stack.pop()
            visited[v] = True
            for q in pairs[v]:
                if visited[q]:
                    anc = ancestor[find_set(q)]
                    lca[v, q] = anc
                    lca[q, v] = anc
            if stack:
                parent = stack[-1][0]
                union(parent, v)
                ancestor[find_set(parent)] = parent


def solve(N, Es, Us, M, moves):
    """Return the sum over moves ``(a, b, c)`` of ``c`` times the total
    vertex cost on the tree path from ``a`` to ``b`` (ends included).

    With root prefix sums ``cum`` and ``l = LCA(a, b)``, the path cost is
    ``cum[a] + cum[b] - 2 * cum[l] + Us[l]``.  Subtracting ``2 * cum[l]``
    removes the shared root-to-``l`` prefix from both terms, including
    ``l`` itself, so ``Us[l]`` is added back once.

    ``M`` is accepted for interface compatibility; ``len(moves)`` is used.
    """
    tree = Tree(N, Es, 0, Us)
    cum_cost = tree.cum_cost
    lca = LCATarjan(tree).find([(a, b) for a, b, _ in moves])
    tax = 0
    for a, b, c in moves:
        v = lca[a, b]
        tax += (cum_cost[a] + cum_cost[b] - 2 * cum_cost[v] + Us[v]) * c
    return tax


def main():
    """Read a test case from stdin and print the answer."""
    print(solve(*read_data()))


if __name__ == "__main__":
    main()