結果

問題 No.399 動的な領主
ユーザー tktk_snsntktk_snsn
提出日時 2021-11-03 13:13:40
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 418 ms / 2,000 ms
コード長 3,581 bytes
コンパイル時間 548 ms
コンパイル使用メモリ 82,304 KB
実行使用メモリ 123,052 KB
最終ジャッジ日時 2024-10-13 02:30:04
合計ジャッジ時間 7,275 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 38 ms
52,736 KB
testcase_01 AC 38 ms
52,736 KB
testcase_02 AC 41 ms
53,376 KB
testcase_03 AC 41 ms
53,480 KB
testcase_04 AC 99 ms
76,048 KB
testcase_05 AC 182 ms
80,384 KB
testcase_06 AC 393 ms
116,496 KB
testcase_07 AC 407 ms
116,796 KB
testcase_08 AC 403 ms
116,452 KB
testcase_09 AC 401 ms
116,684 KB
testcase_10 AC 105 ms
76,704 KB
testcase_11 AC 177 ms
80,632 KB
testcase_12 AC 362 ms
117,172 KB
testcase_13 AC 361 ms
117,076 KB
testcase_14 AC 218 ms
123,052 KB
testcase_15 AC 232 ms
112,160 KB
testcase_16 AC 249 ms
119,424 KB
testcase_17 AC 392 ms
116,892 KB
testcase_18 AC 418 ms
117,120 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
input = sys.stdin.buffer.readline
sys.setrecursionlimit(10 ** 7)


class SegTree(object):
    def __init__(self, N, op_data, u_data):
        self._n = N
        self.log = (N-1).bit_length()
        self.size = 1 << self.log

        self.op = op_data
        self.e = u_data

        self.data = [u_data] * (2 * self.size)
        # self.len = [1] * (2 * self.size)

    def _update(self, i):
        self.data[i] = self.op(self.data[i << 1], self.data[i << 1 | 1])

    def initialize(self, arr=None):
        """ segtreeをarrで初期化する。len(arr) == Nにすること """
        if arr:
            for i, a in enumerate(arr, self.size):
                self.data[i] = a
        for i in reversed(range(1, self.size)):
            self._update(i)
            # self.len[i] = self.len[i << 1] + self.len[i << 1 | 1]

    def update(self, p, x):
        """ data[p] = x とする (0-indexed)"""
        p += self.size
        self.data[p] = x
        for i in range(1, self.log + 1):
            self._update(p >> i)

    def get(self, p):
        """ data[p]を返す """
        return self.data[p + self.size]

    def prod(self, l, r):
        """
        op_data(data[l], data[l+1], ..., data[r-1])を返す (0-indexed)
        """
        sml = self.e
        smr = self.e
        l += self.size
        r += self.size

        while l < r:
            if l & 1:
                sml = self.op(sml, self.data[l])
                l += 1
            if r & 1:
                r -= 1
                smr = self.op(self.data[r], smr)
            l >>= 1
            r >>= 1
        return self.op(sml, smr)

    def all_prod(self):
        """ op(data[0], data[1], ... data[N-1])を返す """
        return self.data[1]


class LowestCommonAncestor(SegTree):
    def __init__(self, N, root, G):
        self.n = N
        self.depth = [0] * N
        self.tout = [-1] * N
        self.tin = [-1] * N
        self.tin[root] = 0
        euler = [0]
        par = [-1] * N
        itr = [0] * N
        que = [root]
        topo = []
        while que:
            s = que[-1]
            if itr[s] < len(G[s]):
                t = G[s][itr[s]]
                itr[s] += 1
                if t == par[s]:
                    continue
                par[t] = s
                self.depth[t] = self.depth[s] + 1
                self.tin[t] = len(euler)
                euler.append(N * self.depth[t] + t)
                que.append(t)
            else:
                topo.append(s)
                p = par[s]
                self.tout[s] = len(euler)
                euler.append(N * self.depth[p] + p)
                que.pop()
        euler.pop()
        self.par = par
        self.topo = topo

        super().__init__(len(euler), min, N * N + 10)
        self.initialize(euler)

    def __call__(self, a, b):
        """LCA(a, b)を返す"""
        l = min(self.tin[a], self.tin[b])
        r = max(self.tout[a], self.tout[b])
        return self.prod(l, r) % self.n


N = int(input())
G = [[] for _ in range(N)]
for _ in range(N - 1):
    a, b = map(lambda x: int(x) - 1, input().split())
    G[a].append(b)
    G[b].append(a)

"""各頂点を通った回数がわかるとOK"""
LCA = LowestCommonAncestor(N, 0, G)
par = LCA.par
cnt = [0] * N

Q = int(input())
for _ in range(Q):
    x, y = map(lambda x: int(x) - 1, input().split())
    z = LCA(x, y)
    cnt[x] += 1
    cnt[y] += 1
    cnt[z] -= 1
    if z != 0:
        cnt[par[z]] -= 1

for s in LCA.topo[:-1]:
    cnt[par[s]] += cnt[s]

ans = 0
for c in cnt:
    ans += c * (c + 1) // 2
print(ans)
0