結果

問題 No.386 貪欲な領主
ユーザー rpy3cpprpy3cpp
提出日時 2016-07-02 00:11:30
言語 PyPy3 (7.3.15)
結果 WA
実行時間 -
コード長 5,214 bytes
コンパイル時間 155 ms
コンパイル使用メモリ 81,864 KB
実行使用メモリ 291,200 KB
最終ジャッジ日時 2024-10-12 19:18:02
合計ジャッジ時間 6,195 ms
ジャッジサーバーID (参考情報) judge3 / judge4
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 44 ms 54,016 KB
testcase_01 AC 42 ms 53,376 KB
testcase_02 AC 43 ms 53,888 KB
testcase_03 AC 43 ms 54,016 KB
testcase_04 WA -
testcase_05 WA -
testcase_06 WA -
testcase_07 AC 112 ms 77,696 KB
testcase_08 AC 228 ms 88,192 KB
testcase_09 AC 123 ms 78,336 KB
testcase_10 AC 41 ms 53,504 KB
testcase_11 AC 40 ms 53,888 KB
testcase_12 AC 98 ms 76,928 KB
testcase_13 AC 172 ms 80,768 KB
testcase_14 WA -
testcase_15 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import math
sys.setrecursionlimit(10**6)

#input = raw_input
#range = xrange


def read_data():
    """Parse the problem instance from standard input.

    Layout read, one item per line:
      N; then N - 1 lines "a b" (undirected edges); then N integers (Us),
      one per line; then M; then M lines "a b c".

    Returns:
        (N, Es, Us, M, moves): Es is the adjacency list (each edge stored in
        both directions), Us the list of N integers, moves a list of
        [a, b, c] lists.
    """
    n_nodes = int(input())
    adjacency = [[] for _ in range(n_nodes)]
    for _ in range(n_nodes - 1):
        left, right = map(int, input().split())
        adjacency[left].append(right)
        adjacency[right].append(left)
    weights = [int(input()) for _ in range(n_nodes)]
    n_moves = int(input())
    trips = [list(map(int, input().split())) for _ in range(n_moves)]
    return n_nodes, adjacency, weights, n_moves, trips


class DisjointSet:
    """Union-find over the integers 0 .. n-1 (union by rank, path compression)."""

    def __init__(self, n):
        self.parent = [node for node in range(n)]
        self.rank = [0 for _ in range(n)]

    def union(self, x, y):
        """Merge the sets that contain x and y."""
        root_x = self.find_set(x)
        root_y = self.find_set(y)
        self._link(root_x, root_y)

    def _link(self, x, y):
        """Hang one root under the other; x and y are expected to be roots.

        The lower-rank root goes under the higher-rank one; on a tie x goes
        under y and y's rank grows by one (this also happens when x == y).
        """
        rank_x, rank_y = self.rank[x], self.rank[y]
        if rank_x > rank_y:
            self.parent[y] = x
            return
        self.parent[x] = y
        if rank_x == rank_y:
            self.rank[y] += 1

    def find_set(self, x):
        """Return the root of x's set, pointing every node on the way directly at it."""
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: compress the path walked above.
        while self.parent[x] != root:
            up = self.parent[x]
            self.parent[x] = root
            x = up
        return root

class Tree:
    """Rooted view of an undirected tree, with root-to-node prefix sums of Us.

    Attributes:
        n: number of nodes.
        root: index of the root node.
        child: child[v] lists the children of v, in adjacency-list order.
        cum_cost: cum_cost[v] is the sum of Us over the path root .. v,
            both ends included.
    """

    def __init__(self, N, Es, root, Us):
        self.n = N
        self.root = root
        self.child = [[] for _ in range(N)]
        self.cum_cost = [0] * N
        self._set_child(Es, Us)

    def _set_child(self, Es, Us):
        """Orient edges away from the root with an explicit-stack DFS and fill cum_cost."""
        finished = [False] * self.n
        self.cum_cost[self.root] = Us[self.root]
        stack = [self.root]
        while stack:
            node = stack.pop()
            base = self.cum_cost[node]
            # Neighbours not yet finished become children; in a tree that is
            # every neighbour except the one we came from.
            kids = [nb for nb in Es[node] if not finished[nb]]
            for kid in kids:
                self.cum_cost[kid] = base + Us[kid]
            self.child[node].extend(kids)
            stack.extend(kids)
            finished[node] = True

class LCArmq:
    """Lowest common ancestor via the Euler-tour to range-minimum reduction.

    The tree is walked in Euler-tour order while recording node depths. The
    shallowest entry of the tour between an occurrence of v and an occurrence
    of w is their lowest common ancestor, so an RMQ over the depth sequence
    answers LCA queries.
    """

    def __init__(self, tree):
        D, E, R = self._convert_to_RMQ(tree.child, tree.root, tree.n)
        self._euler = E      # node visited at each tour position
        self._reverse = R    # one tour position per node (its last one)
        self._RMQ = RMQ(D)   # argmin over the depth sequence

    def _convert_to_RMQ(self, child, root, n):
        """Build the Euler tour used by the LCA -> RMQ reduction.

        A node is emitted once before each of its children's sub-tours and
        once more after the last one, so a node with k children appears
        k + 1 times and the tour has 2 * n - 1 entries.

        The walk uses an explicit stack rather than recursion, so path-like
        trees with ~1e5 levels do not exhaust the interpreter / C stack.

        Returns:
            (depth, euler, reverse): euler[i] is the node at tour position i,
            depth[i] its depth (root = 0), and reverse[node] the last tour
            position of node.
        """
        depth = []
        euler = []
        reverse = [0] * n
        # Frame: (node, its depth, iterator over children not yet descended into).
        stack = [(root, 0, iter(child[root]))]
        while stack:
            node, d, pending = stack[-1]
            euler.append(node)
            depth.append(d)
            nxt = next(pending, None)
            if nxt is None:
                stack.pop()
            else:
                stack.append((nxt, d + 1, iter(child[nxt])))
        for i, node in enumerate(euler):
            reverse[node] = i
        return depth, euler, reverse

    def query(self, v, w):
        """Return the lowest common ancestor of nodes v and w."""
        i, j = self._reverse[v], self._reverse[w]
        rmq = self._RMQ.query(i, j)
        return self._euler[rmq]


class RMQ:
    """Range-minimum-query front end.

    Picks a backend by input length (RMQdoubling below 10**5 elements,
    RMQfaster from 10**5 on) and forwards every query to it.
    """

    def __init__(self, iterable):
        backend = RMQdoubling if len(iterable) < 10**5 else RMQfaster
        self._RMQ = backend(iterable)

    def query(self, i, j):
        """Return the backend's answer for positions i and j."""
        return self._RMQ.query(i, j)


class RMQdoubling(RMQ):
    """Sparse-table ("doubling") range-minimum query over A.

    Preprocessing builds O(n log n) table entries; each query inspects two
    precomputed windows and returns an index of a minimum of A.
    """

    def __init__(self, A):
        self._A = A
        self._preprocess()

    def _preprocess(self):
        """Build the sparse table.

        self._M[j][k] is the index of a minimum of A[k : k + 2**j].
        """
        A = self._A
        n = len(A)
        # Exact floor(log2(n)) via bit_length: math.log2 works in floats and
        # can round up just below a power of two; it also raised on n == 0.
        max_j = n.bit_length() - 1
        self._M = [list(range(n))]
        for j in range(max_j):
            shift = 1 << j
            prev = self._M[j]
            self._M.append([k1 if A[k1] < A[k2] else k2
                            for k1, k2 in zip(prev, prev[shift:])])

    def query(self, i, j):
        """Return an index k between i and j (inclusive, either order) minimizing A[k]."""
        if i == j:
            return i
        if i > j:
            i, j = j, i
        # Largest el with 2**el <= j - i: two windows of length 2**el, one
        # starting at i and one ending at j, then cover [i, j] exactly.
        el = (j - i).bit_length() - 1
        k1 = self._M[el][i]
        k2 = self._M[el][j - (1 << el) + 1]
        return k1 if self._A[k1] < self._A[k2] else k2

class RMQfaster(RMQdoubling):
    """Block-decomposed range-minimum query over D.

    D is cut into blocks of about log2(len(D)) / 4 elements. A sparse table
    (the RMQdoubling base) is built over the block minima. A query uses it
    for the blocks lying entirely inside the range and scans the partial
    blocks at both ends linearly.
    """

    def __init__(self, D):
        self._D = D
        A, self._block_size = self._chop()
        super().__init__(A)

    def _chop(self):
        """Return (list of block minima, block size); the block size is at least 1."""
        n = len(self._D)
        # (bit_length - 1) // 4 == floor(log2(n) / 4) for n >= 1. The max()
        # keeps short inputs from producing a zero block size, which would
        # break range() below and the integer divisions in query().
        block_size = max(1, (n.bit_length() - 1) // 4)
        A = [min(self._D[i:i + block_size]) for i in range(0, n, block_size)]
        return A, block_size

    def _findmin(self, d_min, start, stop):
        """Return the first index in D[start:stop] whose value equals d_min (None if absent)."""
        for i, d in enumerate(self._D[start:stop], start):
            if d == d_min:
                return i

    def query(self, i, j):
        """Return an index k between i and j (inclusive, either order) minimizing D[k]."""
        if i == j:
            return i
        if i > j:
            i, j = j, i
        D = self._D
        s = self._block_size
        first = (i + s - 1) // s    # first block starting at or after i
        last = (j + 1) // s - 1     # last block ending at or before j
        if first > last:
            # No block lies entirely inside [i, j] (the range spans fewer
            # than 2 * s elements), so a direct scan is cheap.
            return min(range(i, j + 1), key=D.__getitem__)
        # Partial prefix [i, first*s) and partial suffix [(last+1)*s, j].
        k_min = None
        d_min = None
        for lo, hi in ((i, first * s), ((last + 1) * s, j + 1)):
            for k in range(lo, hi):
                if k_min is None or D[k] < d_min:
                    k_min, d_min = k, D[k]
        # Blocks first .. last all lie inside [i, j].
        mid_block = super().query(first, last)
        d_mid = self._A[mid_block]
        if k_min is not None and d_min < d_mid:
            return k_min
        return self._findmin(d_mid, mid_block * s, (mid_block + 1) * s)
    

def solve(N, Es, Us, M, moves):
    """Return the total tax over all moves.

    Each move (a, b, c) contributes c times the sum of Us over every node on
    the tree path a .. b, endpoints included. The tree is rooted at node 0;
    M is accepted for interface compatibility but not used.
    """
    tree = Tree(N, Es, 0, Us)
    prefix = tree.cum_cost
    lca = LCArmq(tree)

    def path_sum(a, b):
        # root..a plus root..b counts root..top twice; drop both copies and
        # add the meeting node itself back once.
        top = lca.query(a, b)
        return prefix[a] + prefix[b] - 2 * prefix[top] + Us[top]

    return sum(path_sum(a, b) * c for a, b, c in moves)

if __name__ == "__main__":
    # Script entry point: read the instance from stdin and print the answer.
    # Guarded so importing this module does not block on stdin.
    pars = read_data()
    print(solve(*pars))
0