結果

問題 No.399 動的な領主
ユーザー lam6er
提出日時 2025-04-09 21:01:50
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 1,389 ms / 2,000 ms
コード長 5,069 bytes
コンパイル時間 424 ms
コンパイル使用メモリ 82,968 KB
実行使用メモリ 282,280 KB
最終ジャッジ日時 2025-04-09 21:03:36
合計ジャッジ時間 14,727 ms
ジャッジサーバーID (参考情報) judge4 / judge2
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
other AC * 19
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
# Raise the interpreter's recursion limit to 2**25 (the same value as 1 << 25)
# before the solver runs.
sys.setrecursionlimit(2 ** 25)

def _decompose(n, adj, root=1):
    """Root the tree at ``root`` and compute heavy-light chain heads.

    Everything is done with loops over a BFS order, so deep (for example,
    path-shaped) trees do not depend on the recursion limit.

    Args:
        n: number of vertices, labelled ``1..n``.
        adj: adjacency lists indexed by vertex (``adj[0]`` is unused).
        root: vertex to hang the tree from.

    Returns:
        ``(parent, depth, head, order)`` where:

        - ``parent[root] == 0``; index 0 is an unused sentinel slot.
        - ``depth[root] == 0``.
        - ``head[v]`` is the topmost vertex of v's heavy chain.
        - ``order`` is a BFS order in which every vertex comes after its
          parent.
    """
    parent = [0] * (n + 1)
    depth = [0] * (n + 1)
    order = [root]
    # ``order`` grows while it is scanned, so it doubles as the BFS queue.
    for u in order:
        p = parent[u]
        for v in adj[u]:
            if v != p:
                parent[v] = u
                depth[v] = depth[u] + 1
                order.append(v)

    # Reversed BFS order finishes every child before its parent, so each
    # size[v] is final when it is added to (and compared at) the parent.
    size = [1] * (n + 1)
    heavy = [0] * (n + 1)  # 0 means "no child seen yet"
    for v in reversed(order):
        p = parent[v]
        size[p] += size[v]
        if heavy[p] == 0 or size[v] > size[heavy[p]]:
            heavy[p] = v

    # A heavy child continues its parent's chain; any other vertex starts one.
    head = [0] * (n + 1)
    for v in order:
        p = parent[v]
        head[v] = head[p] if p and heavy[p] == v else v
    return parent, depth, head, order


def _lca(u, v, parent, depth, head):
    """Return the lowest common ancestor of ``u`` and ``v``.

    Repeatedly lifts whichever vertex sits on the chain with the deeper head.
    Once both vertices share a chain, the shallower one is the LCA.
    """
    while head[u] != head[v]:
        if depth[head[u]] < depth[head[v]]:
            u, v = v, u
        u = parent[head[u]]
    return u if depth[u] <= depth[v] else v


def main():
    """Read a tree and path queries from stdin and print the total cost.

    Input tokens, whitespace separated:

    - ``n``, then ``n - 1`` edges ``u v`` (vertices ``1..n``);
    - ``q``, then ``q`` queries ``a b``.

    For every vertex, let ``k`` be the number of query paths ``a..b``
    (inclusive) that pass through it. The function prints
    ``sum(k * (k + 1) // 2)`` over all vertices.
    """
    it = map(int, sys.stdin.read().split())
    n = next(it)
    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        u, v = next(it), next(it)
        adj[u].append(v)
        adj[v].append(u)

    parent, depth, head, order = _decompose(n, adj)

    # Tree difference array: +1 at both endpoints, -1 at the LCA and -1 at
    # the LCA's parent. The subtree sum at v then equals the number of query
    # paths through v. parent[root] is the sentinel slot 0, which is ignored.
    diff = [0] * (n + 1)
    q = next(it)
    for _ in range(q):
        a, b = next(it), next(it)
        c = _lca(a, b, parent, depth, head)
        diff[a] += 1
        diff[b] += 1
        diff[c] -= 1
        diff[parent[c]] -= 1

    total = 0
    for v in reversed(order):  # all descendants of v are already folded in
        k = diff[v]
        total += k * (k + 1) // 2
        diff[parent[v]] += k
    print(total)

# Run the solver only when executed as a script, not when imported.
if __name__ == '__main__':
    main()