結果

問題 No.399 動的な領主
ユーザー kept1994kept1994
提出日時 2023-03-21 01:14:51
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 992 ms / 2,000 ms
コード長 8,530 bytes
コンパイル時間 229 ms
コンパイル使用メモリ 81,912 KB
実行使用メモリ 188,160 KB
最終ジャッジ日時 2024-09-18 14:12:59
合計ジャッジ時間 11,076 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 64 ms
67,584 KB
testcase_01 AC 67 ms
67,968 KB
testcase_02 AC 71 ms
70,528 KB
testcase_03 AC 71 ms
70,272 KB
testcase_04 AC 130 ms
79,196 KB
testcase_05 AC 228 ms
83,828 KB
testcase_06 AC 955 ms
112,840 KB
testcase_07 AC 925 ms
110,368 KB
testcase_08 AC 906 ms
112,924 KB
testcase_09 AC 945 ms
113,664 KB
testcase_10 AC 133 ms
79,360 KB
testcase_11 AC 187 ms
83,172 KB
testcase_12 AC 738 ms
119,712 KB
testcase_13 AC 713 ms
120,008 KB
testcase_14 AC 486 ms
188,032 KB
testcase_15 AC 495 ms
188,160 KB
testcase_16 AC 587 ms
158,336 KB
testcase_17 AC 992 ms
111,232 KB
testcase_18 AC 918 ms
112,848 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#!/usr/bin/env python3
from typing import List, Tuple
from collections import deque
import sys
sys.setrecursionlimit(10 ** 9)

class BIT:
    def __init__(self, N):
        self.N = N
        self.bit = [0] * (self.N + 1) # 1-indexedのため
        
    def add(self, pos, val):
        '''Add
            O(logN)
            posは0-index。内部で1-indexedに変換される。
            A[pos] += val 
        '''
        i = pos + 1 # convert from 0-index to 1-index
        while i <= self.N:
            self.bit[i] += val
            i += i & -i

    def deleteNonNegative(self, pos, val) -> int:
        '''Add
            O(logN)
            ※ multisetで使用される関数
            posは0-index。内部で1-indexedに変換される。
            すでにMultiSetに含まれている個数以上は削除されない。
            A[pos] -= val 
        '''
        actualSubstractVal = min(val, self.sum(pos) - self.sum(pos - 1)) # pos - 1は負になってもself.sum()は大丈夫
        i = pos + 1 # convert from 0-index to 1-index
        while i <= self.N:
            self.bit[i] -= actualSubstractVal
            i += i & -i
        return actualSubstractVal

    def sum(self, pos):
        ''' Sum
            0からposまでの和を返す(posを含む)
            O(logN)
            posは0-index。内部で1-indexedに変換される。
            Return Sum(A[0], ... , A[pos])
            posに負の値を指定されるとSum()すなわち0を返すのでマイナスの特段の考慮不要。
        '''
        res = 0
        i = pos + 1 # convert from 0-index to 1-index
        while i > 0:
            res += self.bit[i]
            i -= i & -i    
        return res
    
    def lowerLeft(self, w):
        '''
        O(logN)
        A0 ~ Aiの和がw以上となる最小のindex(値)を返す。
        Ai ≧ 0であること。
        '''
        if (w < 0):
            return 0
        total = self.sum(self.N - 1)
        if w > total:
            return -1
        x = 0
        k = 1 << (self.N.bit_length() - 1)
        while k > 0:
            if x + k < self.N and self.bit[x + k] < w:
                w -= self.bit[x + k]
                x += k
            k //= 2
        return x
        
    def __str__(self):
        '''
        index0は不使用なので表示しない。
        '''
        return "[" + ", ".join(f'{v}' for v in self.bit[1:]) + "]"
    
class HLD:
    def __init__(self, N, G, root: int = 0):
        self.N = N                  # ノード数
        self.G = G                  # グラフ(隣接リスト)
        self.root = root            # 根
        self.depth_of_node = [0] * self.N                 # 各ノードの根からの距離(深さ)
        self.next_heavy_nodes = [-1] * self.N             # next_heavy_nodes[i] := heavy-pathにおいてノードiの次ノード番号 (0-indexed)
        self.parent_nodes = [None] * self.N               # parent_nodes[i] := ノードiの親ノード番号 (0-indexed)
        self.partial_tree_size = [None] * self.N          # partial_tree_size[i] := ノードi以下の部分木のサイズ

        self.heavy_paths = []                            # 全heavy_pathのリスト
        self.depth_of_heavy_path = []                    # depth_of_heavy_path[i] := i番目のheavy_pathの深さ(幾つのheavy_pathを乗り継いで到達するか)
        # Depth 0: [0] -------------------- [2] - [5] - [6] - [8]
        #           |                        |     |
        # Depth 1: [1] - [3] - [9] - [10]   [4]   [7] 
        #                       |
        # Depth 2:             [11]
        self.to_heavy_path_index = [0] * N               # ノードiはどのheavy_pathに所属するか(heavy_pathsのインデックス)
        self.indices_in_heavy_path = [0] * N             # ノードiがその属するheavy-pathの何番目にあたるのか。
        self.heads = [None] * self.N                     # ノードiが属するheavy-pathの先頭のノード番号
        self.ord = [None] * self.N                       # heavy-pathに割り当てた連続する番号
        # 上の例において
        # [0] -------------------- [5] - [6] - [7] - [8]
        #  |                        |     |
        # [1] - [2] - [3] - [4]    [10]  [11] 
        #              |
        #             [9]

    def path_query_range(self, a: int, b: int) -> List[Tuple[int, int]]:
        """
        ノードaとノードbの間に含まれるordの範囲[l, r)のリストを返す。
        """
        ret = []
        while True:
            if self.ord[a] > self.ord[b]:
                a, b = b, a
            if self.heads[a] == self.heads[b]:
                ret.append((self.ord[a], self.ord[b] + 1))
                return ret
            ret.append((self.ord[self.heads[b]], self.ord[b] + 1))
            b = self.parent_nodes[self.heads[b]]

    def subtree_query_range(self, a: int) -> Tuple[int, int]:
        """
        ノードaの部分木に含まれるordの範囲[l, r)のリストを返す。
        
        return [l, r) range that cover vertices of subtree v"""
        return (self.ord[a], self.ord[a] + self.partial_tree_size[a])

    def lca(self, u, v):
        while True:
            if self.ord[u] > self.ord[v]:
                u, v = v, u
            if self.heads[u] == self.heads[v]:
                return u
            v = self.parent_nodes[self.heads[v]]

    def _dfs(self, cur, depth):
        self.depth_of_node[cur] = depth
        sub_node_count = 1              # そのノード(を含む)配下の子ノード数
        heavy_path = None               # そのノードからのheavy-path
        max_heavy_path_length = 0       # 最大のheavy-path長
        for next in self.G[cur]:
            if self.parent_nodes[cur] == next:
                continue
            self.parent_nodes[next] = cur
            sub_node_total_count = self._dfs(next, depth + 1)
            # より長いパスが見つかったなら
            # heavy-pathとしてその子ノードの番号を仮置きする。
            if max_heavy_path_length < sub_node_total_count:
                heavy_path = next
                max_heavy_path_length = sub_node_total_count
            sub_node_count += sub_node_total_count
        self.partial_tree_size[cur] = sub_node_count
        self.next_heavy_nodes[cur] = heavy_path
        return sub_node_count

    def build(self):
        """
        グラフGをheavy-pathに沿って縮約
        """
        self._dfs(self.root, 0)
        stack = deque([(0, 0)])
        order = 0
        # DFS
        while stack:
            cur, depth_of_cur_heaby_path = stack.pop()
            head_of_heavy_path = cur
            cur_heavy_path = []
            k = len(self.heavy_paths)
            while cur is not None:
                self.ord[cur] = order
                self.heads[cur] = head_of_heavy_path
                self.indices_in_heavy_path[cur] = len(cur_heavy_path)
                cur_heavy_path.append(cur)
                self.to_heavy_path_index[cur] = k
                next_heavy_node = self.next_heavy_nodes[cur]

                # heavy-path以外をキューに入れておき、heavy-path上の次のノードの処理へ。
                for next in self.G[cur]:
                    if self.parent_nodes[cur] == next or next_heavy_node == next:
                        continue
                    stack.append((next, depth_of_cur_heaby_path + 1))
                cur = next_heavy_node
                order += 1
                
            self.heavy_paths.append(cur_heavy_path)
            self.depth_of_heavy_path.append(depth_of_cur_heaby_path)

def main():
    # N = 8
    # M = 7
    # A = [0, 0, 0, 1, 1, 5, 5]
    # B = [1, 2, 3, 4, 5, 6, 7]
    # N = 6
    # M = 5
    # A = [0, 1, 1, 0, 4]
    # B = [1, 2, 3, 4, 5]
    N = int(input())
    G = [[] for _ in range(N)]
    for i in range(N - 1):
        u, v = map(int, input().split())
        u -= 1
        v -= 1
        G[u].append(v)
        G[v].append(u)

    hl = HLD(N, G)
    hl.build()

    bit = BIT(N)
    Q = int(input())
    for _ in range(Q):
        a, b = map(int, input().split())
        a -= 1
        b -= 1
        for l, r in hl.path_query_range(a, b):
            bit.add(l, 1)
            bit.add(r, -1)
    ans = 0
    for i in range(N):
        cnt = bit.sum(i)
        ans += (1 + cnt) * cnt // 2
    print(ans)

if __name__ == '__main__':
    main()
0