結果
問題 | No.386 貪欲な領主 |
ユーザー | rpy3cpp |
提出日時 | 2016-07-02 00:18:41 |
言語 | PyPy2 (7.3.15) |
結果 | RE |
実行時間 | - |
コード長 | 5,214 bytes |
コンパイル時間 | 388 ms |
コンパイル使用メモリ | 76,960 KB |
実行使用メモリ | 258,744 KB |
最終ジャッジ日時 | 2024-10-12 19:33:06 |
合計ジャッジ時間 | 5,333 ms |
ジャッジサーバーID (参考情報) | judge3 / judge2 |
(要ログイン)
テストケース
テストケース | 結果 | 実行時間 / 実行使用メモリ |
---|---|---|
testcase_00 | RE | - |
testcase_01 | RE | - |
testcase_02 | RE | - |
testcase_03 | RE | - |
testcase_04 | RE | - |
testcase_05 | RE | - |
testcase_06 | RE | - |
testcase_07 | RE | - |
testcase_08 | RE | - |
testcase_09 | RE | - |
testcase_10 | RE | - |
testcase_11 | RE | - |
testcase_12 | RE | - |
testcase_13 | RE | - |
testcase_14 | RE | - |
testcase_15 | RE | - |
ソースコード
import sys
import math


def read_data(stream=None):
    """Parse one test case from *stream* (defaults to ``sys.stdin``).

    Token order: N, then N-1 edges ``a b`` (vertex ids used directly as
    indices into a list of size N), then N values U_i, then M, then M
    triples ``a b c``.

    Returns:
        (N, Es, Us, M, moves) where Es is the undirected adjacency list and
        moves is a list of [a, b, c] lists.
    """
    if stream is None:
        stream = sys.stdin
    # Reading every token at once works unchanged on Python 2 and 3 (the
    # previous ``input = raw_input`` alias does not exist on Python 3) and
    # avoids one interpreter-level input() call per line.
    tokens = iter(stream.read().split())
    N = int(next(tokens))
    Es = [[] for _ in range(N)]
    for _ in range(N - 1):
        a, b = int(next(tokens)), int(next(tokens))
        Es[a].append(b)
        Es[b].append(a)
    Us = [int(next(tokens)) for _ in range(N)]
    M = int(next(tokens))
    moves = [[int(next(tokens)) for _ in range(3)] for _ in range(M)]
    return N, Es, Us, M, moves


class DisjointSet(object):
    """Union-find with union by rank and path compression (unused by solve)."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def union(self, x, y):
        self._link(self.find_set(x), self.find_set(y))

    def _link(self, x, y):
        # x and y are roots; linking a root to itself must not bump its rank.
        if x == y:
            return
        if self.rank[x] > self.rank[y]:
            self.parent[y] = x
        else:
            self.parent[x] = y
            if self.rank[x] == self.rank[y]:
                self.rank[y] += 1

    def find_set(self, x):
        """Return the root of x's set, compressing the path iteratively."""
        parent = self.parent
        root = x
        while parent[root] != root:
            root = parent[root]
        # Second pass: point every node on the path directly at the root.
        while parent[x] != root:
            parent[x], x = root, parent[x]
        return root


class Tree(object):
    """Tree rooted at ``root``, built from an undirected adjacency list.

    Attributes:
        child[v]: the children of v.
        cum_cost[v]: sum of Us over the root-to-v path, both ends included.
    """

    def __init__(self, N, Es, root, Us):
        self.n = N
        self.root = root
        self.child = [[] for _ in range(N)]
        self.cum_cost = [0] * N
        self._set_child(Es, Us)

    def _set_child(self, Es, Us):
        # Iterative DFS. Nodes are marked visited when discovered, so no node
        # can ever be attached under two parents.
        visited = [False] * self.n
        visited[self.root] = True
        self.cum_cost[self.root] = Us[self.root]
        stack = [self.root]
        while stack:
            v = stack.pop()
            cost_v = self.cum_cost[v]
            for u in Es[v]:
                if visited[u]:
                    continue
                visited[u] = True
                self.child[v].append(u)
                self.cum_cost[u] = cost_v + Us[u]
                stack.append(u)


class LCArmq(object):
    """Lowest common ancestor via an Euler tour plus range-minimum queries."""

    def __init__(self, tree):
        D, E, R = self._convert_to_RMQ(tree.child, tree.root, tree.n)
        self._euler = E
        self._reverse = R
        self._RMQ = RMQ(D)

    def _convert_to_RMQ(self, child, root, n):
        """LCA preprocessing: reduce LCA to RMQ with an Euler tour.

        Returns (depth, euler, reverse): euler[i] is the i-th node on the
        tour, depth[i] its depth, and reverse[v] the last tour position of v.
        The tour is built with an explicit stack because the recursive form
        recursed as deep as the tree, which can overflow the interpreter
        stack on path-like trees.
        """
        depth = []
        euler = []
        reverse = [0] * n
        node_depth = [0] * n
        next_child = [0] * n  # index into child[v] of the next child to enter
        stack = [root]
        while stack:
            node = stack[-1]
            d = node_depth[node]
            # A node is emitted before descending into each child and once
            # more after its last child, i.e. len(child[node]) + 1 times.
            euler.append(node)
            depth.append(d)
            k = next_child[node]
            if k < len(child[node]):
                next_child[node] = k + 1
                v = child[node][k]
                node_depth[v] = d + 1
                stack.append(v)
            else:
                stack.pop()
        for i, node in enumerate(euler):
            reverse[node] = i
        return depth, euler, reverse

    def query(self, v, w):
        """Return the lowest common ancestor of nodes v and w."""
        i, j = self._reverse[v], self._reverse[w]
        return self._euler[self._RMQ.query(i, j)]


class RMQ(object):
    """Range-minimum-query front end that picks an implementation by size.

    ``query(i, j)`` returns an index of a minimum of the inclusive range.
    """

    def __init__(self, iterable):
        if len(iterable) < 10**5:
            self._RMQ = RMQdoubling(iterable)
        else:
            self._RMQ = RMQfaster(iterable)

    def query(self, i, j):
        return self._RMQ.query(i, j)


class RMQdoubling(RMQ):
    """Sparse-table RMQ: O(n log n) preprocessing, O(1) queries.

    ``query(i, j)`` returns an index k in [min(i, j), max(i, j)] such that
    A[k] is minimal over that inclusive range.
    """

    def __init__(self, A):
        self._A = A
        self._preprocess()

    def _preprocess(self):
        """Build M where M[j][k] is an argmin of A[k : k + 2**j]."""
        A = self._A
        n = len(A)
        # floor(log2(n)) without math.log2, which Python 2 lacks.
        max_j = n.bit_length() - 1
        self._M = [list(range(n))]
        for j in range(max_j):
            shift = 1 << j
            Mj = self._M[j]
            self._M.append([k1 if A[k1] < A[k2] else k2
                            for k1, k2 in zip(Mj, Mj[shift:])])

    def query(self, i, j):
        if i == j:
            return i
        if i > j:
            i, j = j, i
        # Two (possibly overlapping) windows of length 2**el cover [i, j].
        el = (j - i).bit_length() - 1
        A = self._A
        k1 = self._M[el][i]
        k2 = self._M[el][j - (1 << el) + 1]
        return k1 if A[k1] < A[k2] else k2


class RMQfaster(RMQdoubling):
    """Block-decomposed RMQ: a sparse table over per-block minima.

    D is cut into blocks of about log2(n)/4 elements and the sparse table
    only indexes the block minima, shrinking it roughly by the block size.
    A query scans the partial end blocks element by element and answers the
    fully covered blocks in between with one sparse-table lookup.
    """

    def __init__(self, D):
        self._D = D
        self._block_argmin, self._block_size = self._chop()
        # Explicit base call instead of zero-argument super(), which is
        # Python 3 only.
        RMQdoubling.__init__(self, [D[k] for k in self._block_argmin])

    def _chop(self):
        """Return (argmin index of each block of D, block size)."""
        n = len(self._D)
        block_size = max(1, (n.bit_length() - 1) // 4)
        argmins = [self._argmin(lo, min(lo + block_size, n))
                   for lo in range(0, n, block_size)]
        return argmins, block_size

    def _argmin(self, start, stop):
        """Return the first index of a minimum of D[start:stop] (non-empty)."""
        D = self._D
        best = start
        for k in range(start + 1, stop):
            if D[k] < D[best]:
                best = k
        return best

    def query(self, i, j):
        if i == j:
            return i
        if i > j:
            i, j = j, i
        s = self._block_size
        bi, bj = i // s, j // s
        if bi == bj:
            return self._argmin(i, j + 1)
        D = self._D
        # Partial head block [i, end of block bi] and tail block [start of bj, j].
        best = self._argmin(i, (bi + 1) * s)
        tail = self._argmin(bj * s, j + 1)
        if D[tail] < D[best]:
            best = tail
        # Blocks strictly between bi and bj lie entirely inside [i, j].
        if bi + 1 < bj:
            block = RMQdoubling.query(self, bi + 1, bj - 1)
            k = self._block_argmin[block]
            if D[k] < D[best]:
                best = k
        return best


def solve(N, Es, Us, M, moves):
    """Return the total over all moves of c * (sum of Us on the a..b path).

    The tree is rooted at node 0. Each move (a, b, c) contributes c times the
    sum of Us over every node on the tree path from a to b, with a, b and
    their lowest common ancestor each counted once.
    """
    tree = Tree(N, Es, 0, Us)
    cum_cost = tree.cum_cost
    lca_rmq = LCArmq(tree)
    tax = 0
    for a, b, c in moves:
        v = lca_rmq.query(a, b)
        # cum_cost[a] + cum_cost[b] counts the root..v prefix twice: remove
        # both copies, then add v's own cost back once.
        tax += (cum_cost[a] + cum_cost[b] - cum_cost[v] * 2 + Us[v]) * c
    return tax


if __name__ == "__main__":
    print(solve(*read_data()))