結果

問題 No.386 貪欲な領主
ユーザー brthyyjpbrthyyjp
提出日時 2020-12-22 16:58:10
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 579 ms / 2,000 ms
コード長 3,712 bytes
コンパイル時間 180 ms
コンパイル使用メモリ 81,764 KB
実行使用メモリ 135,576 KB
最終ジャッジ日時 2023-10-21 13:00:11
合計ジャッジ時間 5,818 ms
ジャッジサーバーID
(参考情報)
judge15 / judge12
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 39 ms
53,556 KB
testcase_01 AC 36 ms
53,556 KB
testcase_02 AC 36 ms
53,556 KB
testcase_03 AC 36 ms
53,556 KB
testcase_04 AC 429 ms
135,576 KB
testcase_05 AC 546 ms
122,976 KB
testcase_06 AC 579 ms
124,668 KB
testcase_07 AC 109 ms
76,644 KB
testcase_08 AC 239 ms
82,348 KB
testcase_09 AC 136 ms
77,056 KB
testcase_10 AC 36 ms
53,556 KB
testcase_11 AC 37 ms
53,556 KB
testcase_12 AC 87 ms
76,100 KB
testcase_13 AC 164 ms
78,604 KB
testcase_14 AC 552 ms
123,496 KB
testcase_15 AC 399 ms
134,912 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

class SegTree:
    """Iterative (bottom-up) segment tree over an associative operation.

    Leaves live at indices [num, num + n) of ``seg``; internal node i holds
    segfunc(seg[2*i], seg[2*i+1]). Index 0 is unused.

    Args:
        init_val: initial leaf values (0-indexed).
        ide_ele: identity element of ``segfunc``; it also pads unused leaves.
        segfunc: associative binary operation. It does NOT need to be
            commutative: ``query`` combines values in left-to-right order.
    """

    def __init__(self, init_val, ide_ele, segfunc):
        self.n = len(init_val)
        # Smallest power of two >= n, so every leaf index fits.
        self.num = 2**(self.n-1).bit_length()
        self.ide_ele = ide_ele
        self.segfunc = segfunc
        self.seg = [ide_ele]*2*self.num
        # Place the leaves.
        for i in range(self.n):
            self.seg[i+self.num] = init_val[i]
        # Build the internal nodes bottom-up.
        for i in range(self.num-1, 0, -1):
            self.seg[i] = self.segfunc(self.seg[2*i], self.seg[2*i+1])

    def update(self, k, x):
        """Set leaf k (0-indexed) to x and recompute its ancestors."""
        k += self.num
        self.seg[k] = x
        k >>= 1
        # Stop at the root (index 1); index 0 is never read.
        while k:
            self.seg[k] = self.segfunc(self.seg[2*k], self.seg[2*k+1])
            k >>= 1

    def query(self, l, r):
        """Return segfunc folded over leaves [l, r) in index order.

        Returns ``ide_ele`` for an empty range (r <= l).
        """
        if r <= l:
            return self.ide_ele
        seg = self.seg
        f = self.segfunc
        l += self.num
        r += self.num
        # Separate accumulators keep the left-to-right order intact, so
        # non-commutative operations are combined correctly.
        left_acc = self.ide_ele
        right_acc = self.ide_ele
        while l < r:
            if l & 1:
                left_acc = f(left_acc, seg[l])
                l += 1
            if r & 1:
                r -= 1
                right_acc = f(seg[r], right_acc)
            l >>= 1
            r >>= 1
        return f(left_acc, right_acc)

def segfunc(x, y):
    """Combine operation for the range-min SegTree.

    Returns x when x <= y, otherwise y; on ties the left argument wins.
    """
    return x if x <= y else y


# Identity element for segfunc. It must exceed every value stored in the
# tree; here those are the encoded (depth, index) keys built by LCA.
ide_ele = 10 ** 18

class LCA:
    """Lowest-common-ancestor queries via an Euler tour and a range-min SegTree.

    The LCA of u and v is the shallowest vertex appearing in the Euler tour
    between the first occurrences of u and v.

    Attributes set by __init__:
        n: number of vertices (len(g)).
        root: root vertex.
        parent[v]: parent of v; -1 for the root (and for any vertex not
            reachable from root).
        child[v]: children of v in the tree rooted at ``root``.
        order: vertices in reversed DFS preorder, so every vertex comes
            before its parent (handy for bottom-up accumulation).
        eulerTour: vertex sequence of the Euler tour (2n-1 entries when the
            tree is connected).
        left[v] / right[v]: first / last index of v in eulerTour.
        depth[v]: distance from root (root is 0; -1 if unreachable).
        seg: SegTree over the encoded tour (see the end of __init__).
    """

    def __init__(self, g, root):
        """Root the tree at ``root`` and build the Euler tour and SegTree.

        Args:
            g: undirected adjacency list; g[v] lists the neighbours of v.
               Presumably a connected tree — TODO confirm for other callers.
            root: 0-indexed root vertex.
        """
        self.n = len(g)
        self.root = root

        # Iterative DFS from root: record parent/child links and preorder.
        s = [self.root]
        self.parent = [-1]*self.n
        self.child = [[] for _ in range(self.n)]
        self.order = []
        visit = [-1]*self.n
        visit[self.root] = 0
        while s:
            v = s.pop()
            self.order.append(v)
            for u in g[v]:
                if visit[u] == -1:
                    self.parent[u] = v
                    self.child[v].append(u)
                    visit[u] = 0
                    s.append(u)
        # Reverse preorder: every child now precedes its parent.
        self.order.reverse()

        # Euler tour. A vertex v is pushed as-is on entry. ~v (negative) is
        # pushed as an exit marker, popped after v's whole subtree is done.
        tank = [self.root]
        self.eulerTour = []
        self.left = [0]*self.n
        self.right = [-1]*self.n
        self.depth = [-1]*self.n

        eulerNum = -1
        de = -1

        while tank:
            v = tank.pop()
            if v >= 0:
                eulerNum += 1
                self.eulerTour.append(v)
                self.left[v] = eulerNum
                self.right[v] = eulerNum
                tank.append(~v)
                de += 1
                self.depth[v] = de
                for u in self.child[v]:
                    tank.append(u)
            else:
                de -= 1
                if ~v != self.root:
                    # Leaving a subtree returns to the parent, which is
                    # appended to the tour again (its last index so far).
                    self.eulerTour.append(self.parent[~v])
                    eulerNum += 1
                    self.right[self.parent[~v]] = eulerNum

        # Encode (depth, tour index) as one int: depth*(2n-1) + i. A range
        # minimum then picks the shallowest vertex, and the tour index is
        # recovered with % (2n-1). The encoded values must stay below the
        # SegTree identity ide_ele.
        A = [0]*(2*self.n-1)
        for i, e in enumerate(self.eulerTour):
            A[i] = self.depth[e]*(2*self.n-1)+i
        self.seg = SegTree(A, ide_ele, segfunc)

    def getLCA(self, u, v):
        """Return the lowest common ancestor of vertices u and v (0-indexed)."""
        lu = self.left[u]
        lv = self.left[v]
        if lu > lv:
            lu, lv = lv, lu
        # The shallowest vertex between the two first occurrences (inclusive)
        # is the LCA; this is symmetric in u and v.
        m = self.seg.query(lu, lv+1)
        return self.eulerTour[m % (2*self.n-1)]

import sys
import io, os


def total_tax(U, C, order, parent):
    """Turn path-difference counts into subtree sums and weight them by U.

    For each v in ``order`` (every vertex before its parent), adds C[v] into
    C[parent[v]], mutating C in place. Then returns sum(U[i] * C[i]) over the
    first len(U) entries.

    C must have one extra trailing slot: the root has parent -1, so its
    value is added to C[-1] (that sentinel) rather than to a real vertex.
    With the +c/+c/-c/-c updates done in ``main``, C[v] ends up as the total
    weight of the trips passing through v.
    """
    for v in order:
        C[parent[v]] += C[v]
    return sum(u * c for u, c in zip(U, C))


def main():
    """Read the tree, vertex weights U, and weighted trips; print the total."""
    # Read all of stdin in one call. The former os.read(0, os.fstat(0).st_size)
    # read nothing when stdin was a pipe, because st_size is 0 for pipes.
    tokens = iter(sys.stdin.buffer.read().split())

    n = int(next(tokens))
    g = [[] for _ in range(n)]
    for _ in range(n - 1):
        a = int(next(tokens))
        b = int(next(tokens))
        g[a].append(b)
        g[b].append(a)

    U = [int(next(tokens)) for _ in range(n)]

    lca = LCA(g, 0)
    parent = lca.parent

    m = int(next(tokens))
    # C[n] is a sentinel: when the LCA is the root, parent[l] == -1 indexes it.
    C = [0]*(n+1)
    for _ in range(m):
        a = int(next(tokens))
        b = int(next(tokens))
        c = int(next(tokens))
        l = lca.getLCA(a, b)
        # Tree difference update: a and b each add c up to the root. The LCA
        # and its parent subtract it, so exactly the a-b path keeps weight c.
        C[a] += c
        C[b] += c
        C[l] -= c
        C[parent[l]] -= c

    print(total_tax(U, C, lca.order, parent))


if __name__ == "__main__":
    main()