結果

問題 No.399 動的な領主
コンテスト
ユーザー norioc
提出日時 2025-11-12 01:56:57
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 779 ms / 2,000 ms
コード長 8,419 bytes
コンパイル時間 309 ms
コンパイル使用メモリ 82,252 KB
実行使用メモリ 109,500 KB
最終ジャッジ日時 2025-11-12 01:57:10
合計ジャッジ時間 10,708 ms
ジャッジサーバーID (参考情報) judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 19
権限があれば一括ダウンロードができます

ソースコード

diff #

from collections.abc import Callable


class HLD:
    """Heavy-light decomposition of a tree.

    After construction each vertex ``v`` has a position ``vid[v]``.
    ``_reconstruct`` walks each heavy path from its head, assigning
    consecutive positions to it. Only afterwards does it descend into the
    light subtrees hanging off that path. As a result:
      * every heavy path occupies a contiguous run of positions, and
      * the subtree of ``v`` occupies ``[vid[v], vid[v] + subsize[v])``.
    """

    def __init__(self, n: int, adj: dict[int, list[int]], root=0):
        """
        n: number of vertices (labelled 0 .. n-1)
        adj: adjacency {vertex: [neighbor, ...]}. With a dict, vertices
             missing from it are treated as having no neighbors (e.g. ``{}``
             for a single-vertex tree) and the mapping is not modified. A
             list of neighbor lists indexed by vertex is also accepted.
        root: root vertex
        """
        self.n = n
        self.vid = [-1] * n      # vertex -> position in decomposition order
        self.inv = [0] * n       # position -> vertex (inverse of vid)
        self.par = [-1] * n      # parent vertex (-1 for the root)
        self.depth = [0] * n     # number of edges from the root
        self.subsize = [1] * n   # number of vertices in the subtree
        self.head = [0] * n      # topmost vertex of the heavy path containing v
        self.prev = [-1] * n     # parent via a heavy edge (-1 if v heads its heavy path)
        self.next = [-1] * n     # heavy child (-1 if v has no children)
        self.types = [0] * n     # index into _build's `roots` of the tree containing v

        if isinstance(adj, dict):
            # dict.get never inserts, so a caller's defaultdict stays untouched
            # and plain dicts without entries for leaf-less vertices work too.
            adj = [adj.get(v, []) for v in range(n)]
        self._build([root], adj)

    def _build(self, roots, adj):
        """Decompose each tree rooted at a vertex of ``roots``; positions continue across trees."""
        pos = 0
        for i, root in enumerate(roots):
            self._decide_heavy_edge(root, adj)
            pos = self._reconstruct(root, i, pos, adj)

    def _decide_heavy_edge(self, root: int, adj):
        """Fill par/depth/subsize and choose the heavy child of every vertex.

        Uses an explicit-stack post-order DFS, so deep trees do not hit the
        recursion limit. The heavy child is the first child (in adjacency
        order) with the strictly largest subtree.
        """
        st = [(root, 0)]  # (vertex, index of the next neighbor to visit)
        self.par[root] = -1
        self.depth[root] = 0
        while st:
            v, i = st[-1]
            if i < len(adj[v]):
                to = adj[v][i]
                st[-1] = (v, i + 1)
                if to == self.par[v]:
                    continue
                self.par[to] = v
                self.depth[to] = self.depth[v] + 1
                st.append((to, 0))
            else:
                # All children are finished: accumulate sizes, pick the heavy child.
                st.pop()
                heavy = -1
                maxsize = 0
                for to in adj[v]:
                    if to == self.par[v]:
                        continue
                    self.subsize[v] += self.subsize[to]
                    if self.subsize[to] > maxsize:
                        maxsize = self.subsize[to]
                        heavy = to
                if heavy != -1:
                    # Link only the final winner; earlier running maxima are light.
                    self.next[v] = heavy
                    self.prev[heavy] = v

    def _reconstruct(self, root: int, curtype: int, pos: int, adj):
        """Assign positions heavy path by heavy path, starting at ``pos``.

        Returns the first unused position.
        """
        st = [root]  # heads of heavy paths still to be laid out
        while st:
            start = st.pop()
            v = start
            while v != -1:
                self.types[v] = curtype
                self.vid[v] = pos
                self.inv[pos] = v
                self.head[v] = start
                pos += 1
                # Light children start new heavy paths; lay them out later.
                for to in adj[v]:
                    if to != self.par[v] and to != self.next[v]:
                        st.append(to)
                v = self.next[v]

        return pos

    def foreach_nodes(self, u: int, v: int, f: Callable[[int, int], None]):
        """Call ``f(l, r)`` for closed position ranges covering the vertices of the u-v path."""
        while True:
            if self.vid[u] > self.vid[v]:
                u, v = v, u

            # Different heads: vid[u] < vid[head[v]], so this is head[v]..v.
            # Same head: the range is u..v.
            f(max(self.vid[self.head[v]], self.vid[u]), self.vid[v])
            if self.head[u] != self.head[v]:
                v = self.par[self.head[v]]
            else:
                break

    def foreach_edges(self, u: int, v: int, f: Callable[[int, int], None]):
        """Call ``f(l, r)`` for closed position ranges covering the edges of the u-v path.

        Each edge is represented by the position of its lower (child) endpoint.
        """
        while True:
            if self.vid[u] > self.vid[v]:
                u, v = v, u

            if self.head[u] != self.head[v]:
                f(self.vid[self.head[v]], self.vid[v])
                v = self.par[self.head[v]]
            else:
                if u != v:
                    # u is the LCA; its own position is not an edge on the path.
                    f(self.vid[u] + 1, self.vid[v])
                break

    def lca(self, u: int, v: int) -> int:
        """Return the lowest common ancestor of u and v."""
        while True:
            if self.vid[u] > self.vid[v]:
                u, v = v, u
            if self.head[u] == self.head[v]:
                return u
            v = self.par[self.head[v]]


class BucketArray:
    """Integer array with range assign, range add and range sum (sqrt decomposition).

    The array is split into buckets of ``B`` elements. An update that covers
    a whole bucket is recorded lazily: a pending assignment and/or a pending
    addition, with the assignment applied first. A partially covered bucket
    is first materialised with ``_force`` and then updated element by element.
    """

    B = 512  # default bucket width

    def __init__(self, a: list[int], bucket_size=None):
        """
        a: initial values (copied; the caller's sequence is not modified).
           May be empty.
        bucket_size: bucket width; when omitted, the class attribute ``B`` is used.
        """
        if bucket_size is not None:
            if bucket_size <= 0:
                raise ValueError("bucket_size must be positive")
            self.B = bucket_size
        B = self.B
        data = list(a)
        bsize = (len(data) + B - 1) // B
        self.bsize = bsize
        self.has_lazy_set = [False] * bsize  # bucket has a pending assignment
        self.lazy_set_buckets = [0] * bsize  # value of that pending assignment
        self.add_buckets = [0] * bsize       # pending addition to every element of the bucket
        # Bucket sums including any pending assignment but NOT add_buckets
        # (range_sum adds add_buckets[k] * B on top for whole buckets).
        self.sum_buckets = [sum(data[k * B:(k + 1) * B]) for k in range(bsize)]
        self.data = data

    def __iter__(self):
        """Yield all current values in order (flushes every pending lazy update)."""
        for k in range(self.bsize):
            self._force(k)
        yield from self.data

    def __len__(self):
        return len(self.data)

    def _force(self, k: int):
        """Apply bucket k's pending assignment, then its pending addition, to ``data``."""
        start = k * self.B
        stop = min(start + self.B, len(self.data))
        width = stop - start

        if self.has_lazy_set[k]:
            self.has_lazy_set[k] = False
            v = self.lazy_set_buckets[k]
            self.data[start:stop] = [v] * width
            self.sum_buckets[k] = v * width

        v = self.add_buckets[k]
        if v != 0:
            self.data[start:stop] = [y + v for y in self.data[start:stop]]
            self.sum_buckets[k] += v * width
            self.add_buckets[k] = 0

    def range_set(self, l: int, r: int, x: int):
        """Assign x to every element of [l, r). An empty range (l == r) is a no-op."""
        assert 0 <= l <= r <= len(self.data)
        if l == r:
            return
        B = self.B
        for k in range(l // B, self.bsize):
            nl = k * B
            nr = nl + B
            if r <= nl:
                break
            if l <= nl and nr <= r:
                # Whole bucket (necessarily full width, since nr <= r <= len):
                # the assignment supersedes any pending addition.
                self.has_lazy_set[k] = True
                self.lazy_set_buckets[k] = x
                self.add_buckets[k] = 0
                self.sum_buckets[k] = x * B
            else:
                self._force(k)
                start = max(l, nl)
                stop = min(r, nr)
                self.sum_buckets[k] += x * (stop - start) - sum(self.data[start:stop])
                self.data[start:stop] = [x] * (stop - start)

    def range_add(self, l: int, r: int, x: int):
        """Add x to every element of [l, r). An empty range (l == r) is a no-op."""
        assert 0 <= l <= r <= len(self.data)
        if l == r:
            return
        B = self.B
        for k in range(l // B, self.bsize):
            nl = k * B
            nr = nl + B
            if r <= nl:
                break
            if l <= nl and nr <= r:
                self.add_buckets[k] += x
            else:
                self._force(k)
                start = max(l, nl)
                stop = min(r, nr)
                self.data[start:stop] = [y + x for y in self.data[start:stop]]
                self.sum_buckets[k] += x * (stop - start)

    def range_sum(self, l: int, r: int) -> int:
        """Return the sum of [l, r); 0 for an empty range (l == r)."""
        assert 0 <= l <= r <= len(self.data)
        if l == r:
            return 0
        B = self.B
        res = 0
        for k in range(l // B, self.bsize):
            nl = k * B
            nr = nl + B
            if r <= nl:
                break
            if l <= nl and nr <= r:
                res += self.sum_buckets[k] + self.add_buckets[k] * B
            else:
                self._force(k)
                res += sum(self.data[max(l, nl):min(r, nr)])

        return res


class VertexWeightedPathAddHLD:
    """Vertex weights on a tree with path-add and subtree-sum queries.

    ``HLD`` maps vertices to positions in which every path splits into a few
    contiguous ranges and every subtree is one contiguous range.
    ``BucketArray`` stores the weight of the vertex at each position.
    """

    def __init__(self, n: int, adj: dict[int, list[int]], root: int = 0):
        """
        n: number of vertices
        adj: adjacency {vertex: [neighbor, ...]}
        root: root that defines the subtrees used by ``root_sum`` (default 0)
        """
        self.n = n
        self.hld = HLD(n, adj, root=root)
        self.ba = BucketArray([0] * n)  # weights indexed by HLD position

    def add_weight(self, v: int, u: int, weight: int):
        """Add ``weight`` to every vertex on the path between u and v (both inclusive)."""
        ba = self.ba

        def add_segment(l: int, r: int) -> None:
            # foreach_nodes reports closed ranges; BucketArray uses half-open ones.
            ba.range_add(l, r + 1, weight)

        self.hld.foreach_nodes(u, v, add_segment)

    def root_sum(self, v: int) -> int:
        """Return the total weight of the subtree rooted at v (w.r.t. the tree's root)."""
        l = self.hld.vid[v]
        # subsize >= 1, so the subtree range [l, l + subsize) is never empty.
        return self.ba.range_sum(l, l + self.hld.subsize[v])

    def get_all_weights(self) -> list[int]:
        """Return the current weights indexed by vertex (element i is vertex i's weight)."""
        by_position = list(self.ba)
        return [by_position[p] for p in self.hld.vid]


import sys
from collections import defaultdict
from math import comb

# yukicoder No.399: every visit of a path increments each vertex's count.
# A vertex visited k times pays 1 + 2 + ... + k = comb(k + 1, 2) in total.
# Read all tokens at once; per-line input() is slow for large N and Q.
_tokens = iter(sys.stdin.buffer.read().split())

N = int(next(_tokens))
adj = defaultdict(list)
for _ in range(N - 1):
    u = int(next(_tokens)) - 1  # input vertices are 1-indexed
    v = int(next(_tokens)) - 1
    adj[u].append(v)
    adj[v].append(u)

Q = int(next(_tokens))
hld = VertexWeightedPathAddHLD(N, adj)
for _ in range(Q):
    A = int(next(_tokens)) - 1
    B = int(next(_tokens)) - 1
    hld.add_weight(A, B, 1)

ans = sum(comb(x + 1, 2) for x in hld.get_all_weights())
print(ans)