結果

問題 No.399 動的な領主
ユーザー brthyyjpbrthyyjp
提出日時 2020-12-23 10:29:36
言語 PyPy3 (7.3.15)
結果 AC
実行時間 528 ms / 2,000 ms
コード長 3,722 bytes
コンパイル時間 244 ms
コンパイル使用メモリ 82,008 KB
実行使用メモリ 138,396 KB
最終ジャッジ日時 2024-09-21 16:17:55
合計ジャッジ時間 8,272 ms
ジャッジサーバーID (参考情報) judge1 / judge4
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 41 ms 53,380 KB
testcase_01 AC 39 ms 54,464 KB
testcase_02 AC 41 ms 54,784 KB
testcase_03 AC 43 ms 54,228 KB
testcase_04 AC 90 ms 76,240 KB
testcase_05 AC 201 ms 81,932 KB
testcase_06 AC 528 ms 120,904 KB
testcase_07 AC 504 ms 122,504 KB
testcase_08 AC 501 ms 118,736 KB
testcase_09 AC 491 ms 116,684 KB
testcase_10 AC 106 ms 76,992 KB
testcase_11 AC 192 ms 81,896 KB
testcase_12 AC 436 ms 120,908 KB
testcase_13 AC 442 ms 120,908 KB
testcase_14 AC 224 ms 137,792 KB
testcase_15 AC 251 ms 138,396 KB
testcase_16 AC 275 ms 123,384 KB
testcase_17 AC 500 ms 118,752 KB
testcase_18 AC 504 ms 119,084 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

class SegTree:
    """Iterative (bottom-up) segment tree over an associative operation.

    Args:
        init_val: initial element values; element i is stored at leaf i.
        ide_ele: identity element of ``segfunc`` (also fills unused leaves).
        segfunc: associative binary operation. It does NOT need to be
            commutative: ``query`` always combines elements left to right.
    """

    def __init__(self, init_val, ide_ele, segfunc):
        self.n = len(init_val)
        # Number of leaves: a power of two that is >= n.
        self.num = 2**(self.n-1).bit_length()
        self.ide_ele = ide_ele
        self.segfunc = segfunc
        # Node i has children 2i and 2i+1; leaves occupy [num, 2*num).
        # Index 0 is unused.
        self.seg = [ide_ele]*2*self.num
        # Set the leaves.
        for i in range(self.n):
            self.seg[i+self.num] = init_val[i]
        # Build the internal nodes bottom-up.
        for i in range(self.num-1, 0, -1):
            self.seg[i] = self.segfunc(self.seg[2*i], self.seg[2*i+1])

    def update(self, k, x):
        """Set element ``k`` (0-indexed) to ``x`` and refresh its ancestors."""
        k += self.num
        self.seg[k] = x
        # Stop at the root (index 1); index 0 is not part of the tree.
        while k > 1:
            k >>= 1
            self.seg[k] = self.segfunc(self.seg[2*k], self.seg[2*k+1])

    def query(self, l, r):
        """Return the fold of elements in the half-open range ``[l, r)``.

        Elements are combined in index order. Returns ``ide_ele`` when
        ``r <= l``.
        """
        if r <= l:
            return self.ide_ele
        l += self.num
        r += self.num
        # Keep separate accumulators for nodes taken from the left and from
        # the right boundary, so the final result preserves element order.
        res_l = self.ide_ele
        res_r = self.ide_ele
        while l < r:
            if l & 1:
                res_l = self.segfunc(res_l, self.seg[l])
                l += 1
            if r & 1:
                r -= 1
                res_r = self.segfunc(self.seg[r], res_r)
            l >>= 1
            r >>= 1
        return self.segfunc(res_l, res_r)

def segfunc(x, y):
    """Return the smaller of ``x`` and ``y``; on a tie, ``x`` itself is returned.

    This is the combining operation handed to the range-minimum SegTree.
    """
    return x if x <= y else y

# Identity element for ``segfunc``; assumed to exceed every stored value.
ide_ele = 10**18

class LCA:
    """Lowest common ancestor queries via an Euler tour and a range-min tree.

    Args:
        g: adjacency list of an undirected graph on vertices ``0..len(g)-1``,
            assumed to be a connected tree (not checked).
        root: vertex at which the tree is rooted.

    Attributes set by the constructor:
        parent: parent of each vertex (``-1`` for the root).
        child: children lists of the rooted tree.
        order: reversed DFS preorder, so every vertex appears after all of
            its descendants.
        eulerTour: Euler tour vertex sequence (``2n-1`` entries for a tree).
        left, right: first / last index of each vertex in ``eulerTour``.
        depth: depth of each vertex (root is 0).
        seg: SegTree (min) over the depth/position encoding of the tour.
    """

    def __init__(self, g, root):
        self.n = len(g)
        self.root = root

        # Iterative DFS: orient edges away from the root and record a
        # preorder; reversed at the end so descendants precede ancestors.
        self.parent = [-1]*self.n
        self.child = [[] for _ in range(self.n)]
        self.order = []
        seen = [False]*self.n
        seen[self.root] = True
        s = [self.root]
        while s:
            v = s.pop()
            self.order.append(v)
            for u in g[v]:
                if not seen[u]:
                    seen[u] = True
                    self.parent[u] = v
                    self.child[v].append(u)
                    s.append(u)
        self.order.reverse()

        # Euler tour. A non-negative stack entry v means "enter v"; the
        # bitwise complement ~v means "leave v", after which v's parent is
        # written to the tour again.
        tank = [self.root]
        self.eulerTour = []
        self.left = [0]*self.n
        self.right = [-1]*self.n
        self.depth = [-1]*self.n

        eulerNum = -1
        de = -1

        while tank:
            v = tank.pop()
            if v >= 0:
                eulerNum += 1
                self.eulerTour.append(v)
                self.left[v] = eulerNum
                self.right[v] = eulerNum
                tank.append(~v)
                de += 1
                self.depth[v] = de
                for u in self.child[v]:
                    tank.append(u)
            else:
                de -= 1
                if ~v != self.root:
                    p = self.parent[~v]
                    self.eulerTour.append(p)
                    eulerNum += 1
                    self.right[p] = eulerNum

        # Pack (depth, position) into one int: depth*width + position, with
        # position < width. The range minimum is then the shallowest tour
        # entry (earliest position on ties), and ``% width`` recovers the
        # position.
        self._width = 2*self.n - 1
        A = [self.depth[e]*self._width + i for i, e in enumerate(self.eulerTour)]
        self.seg = SegTree(A, ide_ele, segfunc)

    def getLCA(self, u, v):
        """Return the lowest common ancestor of vertices ``u`` and ``v`` (0-indexed).

        The shallowest vertex in the Euler tour between the first occurrences
        of ``u`` and ``v`` (inclusive) is their LCA.
        """
        p = self.left[u]
        q = self.left[v]
        if p > q:
            p, q = q, p
        m = self.seg.query(p, q + 1)
        return self.eulerTour[m % self._width]

import sys
import io, os
# Slurp all of stdin once; each ``input()`` call returns one raw bytes line.
_stdin = io.BytesIO(os.read(0, os.fstat(0).st_size))
input = _stdin.readline

# Tree: n vertices followed by n-1 undirected edges given 1-indexed.
n = int(input())
g = [[] for _ in range(n)]
for _ in range(n - 1):
    x, y = (int(tok) - 1 for tok in input().split())
    g[x].append(y)
    g[y].append(x)

lca = LCA(g, 0)
order = lca.order
parent = lca.parent

# Difference array on vertices: for a path a-b whose LCA is ``top``, add 1
# at both endpoints and subtract 1 at ``top`` and at ``top``'s parent. The
# extra slot at index n absorbs parent[root] == -1 (it is never summed).
C = [0] * (n + 1)
q = int(input())
for _ in range(q):
    a, b = (int(tok) - 1 for tok in input().split())
    top = lca.getLCA(a, b)
    for end in (a, b):
        C[end] += 1
    C[top] -= 1
    C[parent[top]] -= 1

# ``order`` lists descendants before ancestors, so a single pass pushes each
# subtree total into its parent; C[i] then counts the paths through i.
for v in order:
    C[parent[v]] += C[v]

# Each vertex with count c contributes the triangular number c(c+1)/2.
print(sum(c * (c + 1) // 2 for c in C[:n]))
0