結果
| 問題 | No.386 貪欲な領主 |
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2016-07-02 00:22:10 |
| 言語 | PyPy2 (7.3.15) |
| 結果 | RE |
| 実行時間 | - |
| コード長 | 5,219 bytes |
| コンパイル時間 | 2,229 ms |
| コンパイル使用メモリ | 76,672 KB |
| 実行使用メモリ | 259,328 KB |
| 最終ジャッジ日時 | 2024-10-12 19:36:23 |
| 合計ジャッジ時間 | 6,096 ms |
|
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 11 RE * 5 |
ソースコード
import sys
import math
# Raise the recursion limit for any recursive helpers still in the file.
sys.setrecursionlimit(10**6)
try:
    # Python 2 / PyPy2: the builtin input() evaluates the line, so read the
    # raw text instead, and use the lazy xrange for loops.
    input = raw_input
    range = xrange
except NameError:
    # Python 3: input() and range() already behave this way.
    pass
def read_data():
    """Read the problem input line by line via input().

    Layout: N; then N - 1 lines "a b" (undirected edges); then N lines with
    one integer each (Us); then M; then M lines of integers (moves).

    Returns:
        (N, Es, Us, M, moves) where Es[v] lists the neighbours of vertex v
        and every entry of moves is a list of ints.
    """
    def read_ints():
        return [int(tok) for tok in input().split()]

    N = int(input())
    Es = [[] for _ in range(N)]
    for _ in range(N - 1):
        u, w = read_ints()
        Es[u].append(w)
        Es[w].append(u)
    Us = [int(input()) for _ in range(N)]
    M = int(input())
    moves = [read_ints() for _ in range(M)]
    return N, Es, Us, M, moves
class DisjointSet():
    """Union-find over the integers 0..n-1 (union by rank, path compression)."""

    def __init__(self, n):
        self.parent = list(range(n))  # parent[x] == x marks a root
        self.rank = [0] * n           # upper bound on tree height per root

    def union(self, x, y):
        """Merge the sets containing x and y."""
        self._link(self.find_set(x), self.find_set(y))

    def _link(self, x, y):
        """Attach one root under the other, the lower-rank root going below."""
        if x == y:
            # Already the same set; linking a root to itself would otherwise
            # bump its rank for no reason and weaken the balancing.
            return
        if self.rank[x] > self.rank[y]:
            self.parent[y] = x
        else:
            self.parent[x] = y
            if self.rank[x] == self.rank[y]:
                self.rank[y] += 1

    def find_set(self, x):
        """Return the root of x's set, pointing every node on the path at it.

        Iterative, so an arbitrarily long parent chain cannot overflow the
        interpreter stack.
        """
        parent = self.parent
        root = x
        while parent[root] != root:
            root = parent[root]
        while x != root:
            nxt = parent[x]
            parent[x] = root
            x = nxt
        return root
class Tree():
    """Rooted view of an undirected tree given as adjacency lists.

    Attributes:
        n: number of vertices.
        root: index of the root vertex.
        child: child[v] lists v's children, in the order they occur in Es[v].
        cum_cost: cum_cost[v] is the sum of Us over the root-to-v path,
            both ends included.
    """

    def __init__(self, N, Es, root, Us):
        self.n = N
        self.root = root
        self.child = [[] for _ in range(N)]
        self.cum_cost = [0] * N
        self._set_child(Es, Us)

    def _set_child(self, Es, Us):
        """Orient edges away from the root with an explicit-stack DFS."""
        done = [False] * self.n
        cost = self.cum_cost
        cost[self.root] = Us[self.root]
        stack = [self.root]
        while stack:
            node = stack.pop()
            kids = self.child[node]
            base = cost[node]
            for nxt in Es[node]:
                if not done[nxt]:
                    # Only the parent is already finished, so every other
                    # neighbour is a child.
                    kids.append(nxt)
                    cost[nxt] = base + Us[nxt]
                    stack.append(nxt)
            done[node] = True
class LCArmq():
    """Lowest-common-ancestor queries reduced to range-minimum queries.

    The tree is walked in Euler-tour order. The LCA of v and w is the
    shallowest vertex on the tour between an occurrence of v and an
    occurrence of w.
    """

    def __init__(self, tree):
        # `tree` must expose `child` (children lists), `root` and `n`.
        D, E, R = self._convert_to_RMQ(tree.child, tree.root, tree.n)
        self._euler = E     # vertex at each tour position
        self._reverse = R   # last tour position of each vertex
        self._RMQ = RMQ(D)  # range-minimum over the depth at each position

    def _convert_to_RMQ(self, child, root, n):
        """Build the Euler tour of the tree rooted at `root`.

        A vertex is emitted before descending into each child and once more
        after its last child. The tour therefore has 2k - 1 entries, where k
        is the number of vertices reachable from `root`.

        Uses an explicit stack instead of recursion, so deep (path-like)
        trees cannot overflow the interpreter stack.

        Returns:
            (depth, euler, reverse):
            - euler[i] is the vertex at tour position i.
            - depth[i] is that vertex's depth (root = 0).
            - reverse[v] is the last position of v; it stays 0 for vertices
              not reached from `root`.
        """
        depth = []
        euler = []
        reverse = [0] * n
        stack = [(root, 0, iter(child[root]))]
        while stack:
            node, d, kids = stack[-1]
            # Every visit to the top frame emits it: once per child and once
            # on the final exit.
            euler.append(node)
            depth.append(d)
            nxt = next(kids, None)
            if nxt is None:
                stack.pop()
            else:
                stack.append((nxt, d + 1, iter(child[nxt])))
        for i, node in enumerate(euler):
            reverse[node] = i
        return depth, euler, reverse

    def query(self, v, w):
        """Return the lowest common ancestor of vertices v and w."""
        i, j = self._reverse[v], self._reverse[w]
        return self._euler[self._RMQ.query(i, j)]
class RMQ():
    """Range-minimum query over a fixed sequence.

    Picks an implementation by size: a plain sparse table below 10**5
    elements, and a block-decomposed variant with a smaller table otherwise.
    query(i, j) returns an index k with min(i, j) <= k <= max(i, j) whose
    element is minimal over that closed range.
    """

    def __init__(self, iterable):
        # `iterable` must support len() and indexing (a list in practice).
        if len(iterable) < 10**5:
            self._RMQ = RMQdoubling(iterable)
        else:
            self._RMQ = RMQfaster(iterable)

    def query(self, i, j):
        """Index of a minimum over the closed range between i and j."""
        return self._RMQ.query(i, j)


class RMQdoubling(RMQ):
    """Sparse-table RMQ: O(n log n) preprocessing, O(1) per query."""

    def __init__(self, A):
        # Deliberately skips RMQ.__init__, which would dispatch again.
        self._A = A
        self._preprocess()

    def _preprocess(self):
        """Fill self._M, where self._M[j][k] indexes a minimum of A[k:k + 2**j]."""
        A = self._A
        n = len(A)
        # Exact floor(log2(n)). Float math.log(x, 2) is not guaranteed to be
        # exact, so int() of it can land on the wrong integer.
        max_j = n.bit_length() - 1
        self._M = [list(range(n))]
        for j in range(max_j):
            shift = 1 << j
            prev = self._M[j]
            self._M.append([k1 if A[k1] < A[k2] else k2
                            for k1, k2 in zip(prev, prev[shift:])])

    def query(self, i, j):
        """Index of a minimum of A over the closed range between i and j."""
        if i == j:
            return i
        if i > j:
            i, j = j, i
        # Two power-of-two windows, possibly overlapping, that cover [i, j].
        el = (j - i).bit_length() - 1
        k1 = self._M[el][i]
        k2 = self._M[el][j - (1 << el) + 1]
        return k1 if self._A[k1] < self._A[k2] else k2


class RMQfaster(RMQdoubling):
    """Block-decomposed RMQ for long sequences.

    D is cut into blocks of max(1, int(log2(len(D)) / 4)) elements. The
    inherited sparse table, built over the per-block minima, answers the
    blocks lying fully inside a query range. The partial blocks at either
    end are scanned directly.
    """

    def __init__(self, D):
        self._D = D
        A, self._block_size = self._chop()
        # Explicit base-class call: these classes have no `object` base, so
        # under Python 2 they are classic classes and super() is unavailable.
        RMQdoubling.__init__(self, A)  # sets self._A = A (block minima)

    def _chop(self):
        """Return (per-block minima of self._D, block size)."""
        n = len(self._D)
        # The log-based size is 0 for short inputs, and log is undefined at 0.
        block_size = max(1, int(math.log(n, 2) / 4)) if n else 1
        A = [min(self._D[k:k + block_size]) for k in range(0, n, block_size)]
        return A, block_size

    def _argmin(self, start, stop):
        """Index of the first minimum of self._D[start:stop] (requires start < stop)."""
        D = self._D
        best = start
        for k in range(start + 1, stop):
            if D[k] < D[best]:
                best = k
        return best

    def _findmin(self, d_min, start, stop):
        """Index of the first element equal to d_min in self._D[start:stop]."""
        for i, d in enumerate(self._D[start:stop], start):
            if d == d_min:
                return i

    def query(self, i, j):
        """Index of a minimum of D over the closed range between i and j."""
        if i > j:
            i, j = j, i
        s = self._block_size
        D = self._D
        first, last = i // s, j // s
        if first == last:
            # Both ends lie in one block: a direct scan is cheapest.
            return self._argmin(i, j + 1)
        # Partial blocks at each end. Both are non-empty and stay inside [i, j].
        best = self._argmin(i, (first + 1) * s)
        tail = self._argmin(last * s, j + 1)
        if D[tail] < D[best]:
            best = tail
        if first + 1 < last:
            # Blocks first+1 .. last-1 lie entirely within [i, j].
            blk = RMQdoubling.query(self, first + 1, last - 1)
            if self._A[blk] < D[best]:
                best = self._findmin(self._A[blk], blk * s, (blk + 1) * s)
        return best
def solve(N, Es, Us, M, moves):
    """Sum, over every (a, b, c) in moves, of c times the Us-sum on path a..b.

    With prefix[v] being the Us-sum from root 0 down to v (inclusive), the
    path sum is prefix[a] + prefix[b] - 2 * prefix[lca] + Us[lca], since the
    LCA is subtracted twice and must be added back once. M is unused.
    """
    tree = Tree(N, Es, 0, Us)
    prefix = tree.cum_cost
    lca = LCArmq(tree).query
    total = 0
    for a, b, c in moves:
        top = lca(a, b)
        path_sum = prefix[a] + prefix[b] - 2 * prefix[top] + Us[top]
        total += path_sum * c
    return total
# Script entry: read the whole input, then print the single answer.
pars = read_data()
answer = solve(*pars)
print(answer)