結果

問題 No.386 貪欲な領主
ユーザー matsu7874
提出日時 2016-07-19 02:08:11
言語 PyPy3 (7.3.15)
結果 TLE
実行時間 -
コード長 3,883 bytes
コンパイル時間 185 ms
コンパイル使用メモリ 82,416 KB
実行使用メモリ 159,856 KB
最終ジャッジ日時 2024-04-23 16:38:44
合計ジャッジ時間 4,175 ms
ジャッジサーバーID (参考情報) judge4 / judge5
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 48 ms 53,120 KB
testcase_01 AC 42 ms 52,736 KB
testcase_02 AC 45 ms 53,760 KB
testcase_03 AC 44 ms 53,888 KB
testcase_04 TLE -
testcase_05 TLE -
testcase_06 TLE -
testcase_07 AC 203 ms 77,856 KB
testcase_08 AC 815 ms 92,068 KB
testcase_09 AC 412 ms 80,328 KB
testcase_10 AC 44 ms 52,992 KB
testcase_11 AC 42 ms 52,992 KB
testcase_12 AC 213 ms 77,696 KB
testcase_13 AC 317 ms 80,472 KB
testcase_14 TLE -
testcase_15 TLE -
権限があれば一括ダウンロードができます

ソースコード

diff #

class SegmentTree():
    """Array-backed segment tree answering range-minimum queries (RMQ).

    Nodes live in a flat list ``st`` of length ``2 * size - 1``, where
    ``size`` is the leaf count rounded up to a power of two. Node ``k`` has
    children ``2k + 1`` and ``2k + 2``, and element ``i`` is stored at leaf
    ``size - 1 + i``.

    Two construction modes:

    * ``initial`` is a list: leaves hold ``(value, index)`` pairs, so a query
      returns both the minimum and the position where it occurs (the smallest
      index wins ties). The ``size`` argument is ignored in this mode.
    * otherwise: ``size`` leaves are all set to ``initial`` and hold bare
      values.
    """

    def __init__(self, size, initial=0):
        self._indexed = isinstance(initial, list)
        leaf_count = len(initial) if self._indexed else size
        node_size = 1
        while node_size < leaf_count:
            node_size <<= 1
        self.size = node_size

        if self._indexed:
            # Padding leaves compare greater than every real (value, index).
            self.st = [(float('inf'), float('inf'))] * (2 * node_size - 1)
            for i, value in enumerate(initial):
                self.st[node_size - 1 + i] = (value, i)
            # Fill internal nodes bottom-up: each is the min of its children.
            for k in range(node_size - 2, -1, -1):
                self.st[k] = min(self.st[2 * k + 1], self.st[2 * k + 2])
            # Value returned for an empty range in list mode.
            self._identity = (float('inf'), -1)
        else:
            self.st = [initial] * (2 * node_size - 1)
            # Bare-value identity so comparisons stay within one type.
            self._identity = float('inf')

    def update(self, target, new_value):
        """Set element ``target`` to ``new_value`` and refresh its ancestors.

        In list mode the leaf stores ``(new_value, target)`` so it stays
        comparable with the ``(value, index)`` pairs in the other leaves.
        """
        leaf = target + self.size - 1
        self.st[leaf] = (new_value, target) if self._indexed else new_value
        while leaf > 0:
            leaf = (leaf - 1) // 2
            self.st[leaf] = min(self.st[2 * leaf + 1], self.st[2 * leaf + 2])

    def query(self, a, b):
        """Return the minimum over the half-open index range ``[a, b)``.

        The range is clamped to ``[0, size)``. An empty range yields
        ``(inf, -1)`` in list mode and ``inf`` in scalar mode.
        """
        a = max(a, 0)
        b = min(b, self.size)
        st = self.st
        res = self._identity
        # Iterative bottom-up walk using 1-based heap numbering
        # (1-based node k is st[k - 1]). This avoids the per-call recursion
        # overhead of a top-down query.
        lo = a + self.size
        hi = b + self.size
        while lo < hi:
            if lo & 1:
                res = min(res, st[lo - 1])
                lo += 1
            if hi & 1:
                hi -= 1
                res = min(res, st[hi - 1])
            lo >>= 1
            hi >>= 1
        return res


class Tree():
    """Tree with root-path cost sums and LCA queries (Euler tour + RMQ).

    Usage: add edges with ``create_nodes`` / ``create_nodes_bidirectional``,
    call ``build(cost, root)``, then ``query(a, b)`` returns the lowest
    common ancestor of ``a`` and ``b``.
    """

    def __init__(self, size):
        self.size = size
        self.edges = [[] for _ in range(size)]

    def create_nodes(self, s, t):
        """Add a directed edge ``s -> t``."""
        self.edges[s].append(t)

    def create_nodes_bidirectional(self, s, t):
        """Add an undirected edge between ``s`` and ``t``."""
        self.edges[s].append(t)
        self.edges[t].append(s)

    def build(self, cost, root=0):
        """Root the tree at ``root`` and precompute the LCA/cost tables.

        Sets these attributes:
          route       -- Euler tour: each node is recorded when first entered
                         and again each time the walk returns to it from a
                         child.
          cost        -- cost[v] is the sum of ``cost`` over the root..v path,
                         both ends included (0 for nodes unreachable from
                         ``root``).
          first_occur -- index of each node's first appearance in ``route``
                         (-1 for nodes unreachable from ``root``).
          rmq         -- SegmentTree over the node depths along ``route``.
        """
        n = self.size
        edges = self.edges
        depth = [0] * n
        visited = [False] * n
        # Per-node cursor into its adjacency list. Resuming from the cursor,
        # instead of rescanning from index 0 on every return, keeps the walk
        # linear overall; rescanning is O(deg^2) for high-degree nodes.
        cursor = [0] * n
        self.cost = [0] * n
        self.first_occur = [-1] * n

        self.cost[root] = cost[root]
        visited[root] = True
        self.first_occur[root] = 0
        route = [root]
        stack = [root]
        while stack:
            node = stack[-1]
            adj = edges[node]
            k = cursor[node]
            while k < len(adj) and visited[adj[k]]:
                k += 1
            if k < len(adj):
                # Descend into the next unvisited neighbour.
                child = adj[k]
                cursor[node] = k + 1
                visited[child] = True
                depth[child] = depth[node] + 1
                self.cost[child] = self.cost[node] + cost[child]
                self.first_occur[child] = len(route)
                route.append(child)
                stack.append(child)
            else:
                stack.pop()
                if stack:
                    # Back at the parent: record it again in the tour.
                    route.append(stack[-1])

        self.route = route
        self.rmq = SegmentTree(n, initial=[depth[v] for v in route])

    def query(self, a, b):
        """Return the lowest common ancestor of nodes ``a`` and ``b``."""
        first, second = sorted((self.first_occur[a], self.first_occur[b]))
        # The LCA is the shallowest node on the tour between the two first
        # occurrences, both ends included. The inclusive upper bound also
        # makes query(a, a) return a rather than an empty-range sentinel.
        _, min_i = self.rmq.query(first, second + 1)
        return self.route[min_i]


import sys

# Read every token in one call: per-line input() is too slow for ~10^5
# edges and queries. The input order is: n, then n - 1 edges "a b", then
# n tariffs, then m, then m queries "a b c".
_ints = map(int, sys.stdin.buffer.read().split())

n = next(_ints)
tree = Tree(n)
for _ in range(n - 1):
    a, b = next(_ints), next(_ints)
    tree.create_nodes_bidirectional(a, b)

tariff = [next(_ints) for _ in range(n)]
tree.build(tariff)

total = 0
m = next(_ints)
# Path cost depends only on the unordered endpoint pair, so cache it.
memo = {}
for _ in range(m):
    a, b, c = next(_ints), next(_ints), next(_ints)
    if a > b:
        a, b = b, a
    key = (a, b)
    path_cost = memo.get(key)
    if path_cost is None:
        lca = tree.query(a, b)
        # Sum over the a..b path: two root-prefix sums minus the shared
        # prefix counted twice, plus the LCA's own tariff, which was
        # subtracted once too often.
        path_cost = (tree.cost[a] + tree.cost[b]
                     - 2 * tree.cost[lca] + tariff[lca])
        memo[key] = path_cost
    total += path_cost * c
print(total)
0