結果

問題 No.399 動的な領主
ユーザー rpy3cpprpy3cpp
提出日時 2016-07-16 00:38:05
言語 PyPy2
(7.3.15)
結果
MLE  
実行時間 -
コード長 3,453 bytes
コンパイル時間 279 ms
コンパイル使用メモリ 77,100 KB
実行使用メモリ 388,472 KB
最終ジャッジ日時 2024-04-23 13:03:08
合計ジャッジ時間 11,886 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 75 ms
76,032 KB
testcase_01 AC 78 ms
76,032 KB
testcase_02 AC 84 ms
76,956 KB
testcase_03 AC 82 ms
77,100 KB
testcase_04 AC 115 ms
79,032 KB
testcase_05 AC 255 ms
91,820 KB
testcase_06 AC 929 ms
218,772 KB
testcase_07 AC 941 ms
218,176 KB
testcase_08 AC 853 ms
217,720 KB
testcase_09 AC 836 ms
217,508 KB
testcase_10 AC 142 ms
80,384 KB
testcase_11 AC 203 ms
89,528 KB
testcase_12 AC 628 ms
199,132 KB
testcase_13 AC 661 ms
225,640 KB
testcase_14 MLE -
testcase_15 MLE -
testcase_16 AC 736 ms
273,048 KB
testcase_17 AC 857 ms
212,148 KB
testcase_18 AC 858 ms
214,140 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import math
sys.setrecursionlimit(10**6)

input = raw_input
range = xrange


def read_data():
    N = int(input())
    Es = [[] for i in range(N)]
    for i in range(N - 1):
        u, v = map(int, input().split())
        u -= 1
        v -= 1
        Es[u].append(v)
        Es[v].append(u)
    return N, Es


class Tree():
    def __init__(self, N, Es, root):
        self.n = N
        self.root = root
        self.child = [[] for i in range(N)]
        self._set_child(Es)
    
    def _set_child(self, Es):
        que = [self.root]
        visited = [False] * self.n
        while que:
            v = que.pop()
            for u in Es[v]:
                if visited[u]:
                    continue
                self.child[v].append(u)
                que.append(u)
            visited[v] = True        

class LCArmq():
    def __init__(self, tree):
        D, E, R = self._convert_to_RMQ(tree.child, tree.root, tree.n)
        self._euler = E
        self._reverse = R
        self._RMQ = RMQ(D)

    def _convert_to_RMQ(self, child, root, n):
        ''' LCA の前処理。 RMQ に置き換えるため、Euler tour で巡回して深さのリストをつくる。
        '''
        depth = []
        euler = []
        reverse = [0] * n
        
        def euler_tour(node, d, depth, euler):
            for v in child[node]:
                euler.append(node)
                depth.append(d)
                euler_tour(v, d + 1, depth, euler)
            euler.append(node)
            depth.append(d)

        euler_tour(root, 0, depth, euler)
        for i, node in enumerate(euler):
            reverse[node] = i
        return depth, euler, reverse

    def query(self, v, w):
        i, j = self._reverse[v], self._reverse[w]
        rmq = self._RMQ.query(i, j)
        lca = self._euler[rmq]
        return lca


class RMQ():
    def __init__(self, A):
        self._A = A
        self._preprocess()

    def _preprocess(self):
        ''' RMQ の前処理。
        '''
        n = len(self._A)
        max_j = int(math.log(n, 2))
        self._M = [list(range(n))]
        for j in range(0, max_j):
            shift = 1 << j
            Mj = self._M[j]
            Mjnext = []
            for k1, k2 in zip(Mj, Mj[shift:]):
                k = k1 if self._A[k1] < self._A[k2] else k2
                Mjnext.append(k)
            self._M.append(Mjnext)

    def query(self, i, j):
        if i == j: return i
        if i > j: i, j = j, i
        el = int(math.log(j - i, 2))
        k1 = self._M[el][i]
        k2 = self._M[el][j - (1 << el) + 1]
        rmq = k1 if self._A[k1] < self._A[k2] else k2
        return rmq


def solve(N, Es):
    global imos1, imos2, imos3
    tree = Tree(N, Es, 0)
    lca_rmq = LCArmq(tree)
    imos1 = [0] * N
    imos2 = [0] * N
    Q = int(input())
    for q in range(Q):
        a, b = map(int, input().split())
        a -= 1
        b -= 1
        v = lca_rmq.query(a, b)
        imos1[a] += 1
        imos1[b] += 1
        imos1[v] -= 2
        imos2[v] += 1
    imos3 = [0] * N
    dfs(0, tree.child)
    return count_tax(imos3, imos2, N)

def dfs(u, child):
    global imos3
    ret = imos1[u]
    for v in child[u]:
        ret += dfs(v, child)
    imos3[u] = ret
    return ret

def count_tax(imos3, imos2, N):
    tax = 0
    for i in range(N):
        t = imos3[i] + imos2[i]
        tax += t * (t + 1)
    return tax // 2


N, Es = read_data()
print(solve(N, Es))
0