結果

問題 No.3348 Tree Balance
コンテスト
ユーザー Nzt3
提出日時 2025-11-13 01:06:16
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 4,172 ms / 5,000 ms
コード長 5,841 bytes
コンパイル時間 344 ms
コンパイル使用メモリ 82,036 KB
実行使用メモリ 277,908 KB
最終ジャッジ日時 2025-11-13 21:25:45
合計ジャッジ時間 35,023 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 25
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

# Raise the recursion ceiling as headroom.  The only recursive code is
# AVLTree._insert, whose depth is bounded by the tree height (kept small by
# AVL rebalancing); both DFS passes in main() use explicit stacks, so this
# limit is presumably not strictly required.
sys.setrecursionlimit(200010)


class Node:
    """A single AVL tree node: a stored value, two children and a cached height."""

    # One Node is allocated for every value inserted into any AVLTree, so
    # instances are numerous; __slots__ drops the per-instance __dict__,
    # reducing memory use and speeding up attribute access.
    __slots__ = ("value", "left", "right", "height")

    def __init__(self, value):
        self.value = value
        self.left = None  # subtree holding smaller values
        self.right = None  # subtree holding larger values
        self.height = 1  # height of the subtree rooted here; a leaf has height 1


class AVLTree:
    """Self-balancing binary search tree holding a *set* of values.

    Inserting a value that is already present is a no-op, so ``len()``
    counts distinct values.  Iteration yields the values in ascending order.
    """

    def __init__(self):
        self.root = None  # top node, or None while the tree is empty
        self.size = 0  # number of distinct values stored

    def __len__(self):
        return self.size

    def _get_height(self, node):
        """Return the height of ``node``'s subtree (0 for an empty subtree)."""
        return node.height if node else 0

    def _get_balance(self, node):
        """Return left height minus right height (0 for an empty subtree)."""
        if not node:
            return 0
        left_h = self._get_height(node.left)
        right_h = self._get_height(node.right)
        return left_h - right_h

    def _fix_height(self, node):
        """Recompute ``node.height`` from its children's cached heights."""
        node.height = 1 + max(self._get_height(node.left),
                              self._get_height(node.right))

    def _right_rotate(self, y):
        """Rotate ``y`` down to the right; return the new subtree root."""
        pivot = y.left
        # pivot's right subtree moves under y, and y hangs off pivot.right.
        y.left, pivot.right = pivot.right, y
        self._fix_height(y)  # y is now the lower node, so fix it first
        self._fix_height(pivot)
        return pivot

    def _left_rotate(self, x):
        """Rotate ``x`` down to the left; return the new subtree root."""
        pivot = x.right
        # pivot's left subtree moves under x, and x hangs off pivot.left.
        x.right, pivot.left = pivot.left, x
        self._fix_height(x)  # x is now the lower node, so fix it first
        self._fix_height(pivot)
        return pivot

    def insert(self, value):
        """Add ``value`` to the set; duplicates are ignored."""
        self.root = self._insert(self.root, value)

    def _insert(self, root, value):
        """Insert into the subtree at ``root`` and return its (new) root."""
        if not root:
            self.size += 1
            return Node(value)

        if value < root.value:
            root.left = self._insert(root.left, value)
        elif value > root.value:
            root.right = self._insert(root.right, value)
        else:
            return root  # already present: nothing changed below

        self._fix_height(root)
        tilt = self._get_balance(root)

        if tilt > 1:  # left-heavy
            if value < root.left.value:  # left-left
                return self._right_rotate(root)
            if value > root.left.value:  # left-right
                root.left = self._left_rotate(root.left)
                return self._right_rotate(root)
        elif tilt < -1:  # right-heavy
            if value > root.right.value:  # right-right
                return self._left_rotate(root)
            if value < root.right.value:  # right-left
                root.right = self._right_rotate(root.right)
                return self._left_rotate(root)

        return root

    def find_lower_bound(self, value):
        """Return the smallest stored value >= ``value``, or None."""
        answer = None
        cursor = self.root
        while cursor:
            here = cursor.value
            if here == value:
                return value
            if here > value:
                answer = here  # candidate; look left for something smaller
                cursor = cursor.left
            else:
                cursor = cursor.right
        return answer

    def find_predecessor(self, value):
        """Return the largest stored value strictly < ``value``, or None."""
        answer = None
        cursor = self.root
        while cursor:
            if cursor.value < value:
                answer = cursor.value  # candidate; look right for a larger one
                cursor = cursor.right
            else:
                cursor = cursor.left
        return answer

    def __iter__(self):
        """Yield stored values in ascending order (iterative in-order walk)."""
        pending = []
        node = self.root
        while True:
            while node:
                pending.append(node)
                node = node.left
            if not pending:
                return
            node = pending.pop()
            yield node.value
            node = node.right

    def merge(self, other_tree):
        """Insert every value of ``other_tree`` into this tree."""
        for item in other_tree:
            self.insert(item)


def main():
    """Solve one test case read from stdin and print the answer.

    Input (as parsed below): ``n``; a line of ``n`` vertex weights; ``n - 1``
    edges given as 1-indexed vertex pairs.  Cutting two edges splits the tree
    into three components; the printed value is the smallest
    ``max - min`` of the three component weights over the candidate cut
    pairs examined below (``1 << 60`` if no candidate was examined).
    """
    I = sys.stdin.readline
    O = sys.stdout.write

    n = int(I())
    w = list(map(int, I().split()))
    graph = [[] for _ in range(n)]
    for _ in range(n - 1):
        a, b = map(int, I().split())
        a -= 1
        b -= 1
        graph[a].append(b)
        graph[b].append(a)

    # One iterative DFS from vertex 0 records every vertex's parent and a
    # preorder.  Walking that preorder backwards visits each child before its
    # parent, which is all that both bottom-up passes below need.
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
        v = stack.pop()
        order.append(v)
        for u in graph[v]:
            if u != parent[v]:
                parent[u] = v
                stack.append(u)
    bottom_up = order[::-1]

    # sum_[v] = total weight of the subtree rooted at v.
    sum_ = list(w)
    for v in bottom_up:
        if parent[v] >= 0:
            sum_[parent[v]] += sum_[v]

    total = sum_[0]
    ans = 1 << 60

    def update(s1, s2):
        """Score the split whose components weigh s1, s2 and total - s1 - s2."""
        nonlocal ans
        s3 = total - s1 - s2
        ans = min(ans, max(s1, s2, s3) - min(s1, s2, s3))

    # res_map[v] is the set of subtree sums of every vertex in v's subtree
    # (v included).  Each entry is consumed exactly once, by v's parent.
    res_map = {}
    for v in bottom_up:
        sums = AVLTree()  # union of the child sets processed so far
        current = sum_[v]
        remain = total - current  # weight outside v's subtree

        for u in graph[v]:
            if u == parent[v]:
                continue

            # pop rather than index: the child's tree is either merged into
            # `sums` or becomes `sums`, so the map must not keep it alive.
            # Leaving it in the map would retain every discarded smaller tree
            # (O(n log n) nodes overall instead of O(n)).
            res = res_map.pop(u)

            # Cut the edge above v plus the edge above some vertex inside u's
            # subtree: the inner piece is best when it weighs as close to
            # current / 2 as possible, so try the nearest value on each side.
            # NOTE(review): at the root remain == 0, so these candidates are
            # not real splits; with nonnegative weights their score can never
            # undercut a genuine split — confirm weights are nonnegative.
            X1 = current // 2

            itr_lower = res.find_lower_bound(X1)
            if itr_lower is not None:
                update(remain, itr_lower)

            itr_pred = res.find_predecessor(X1)
            if itr_pred is not None:
                update(remain, itr_pred)

            # Cut the edges above two vertices lying in different child
            # subtrees of v.  Small-to-large: iterate the smaller set and look
            # up, in the larger one, the partner nearest (total - e) / 2.
            if len(sums) < len(res):
                sums, res = res, sums

            for e in res:
                X2 = (total - e) // 2

                itr_lower_2 = sums.find_lower_bound(X2)
                if itr_lower_2 is not None:
                    update(e, itr_lower_2)

                itr_pred_2 = sums.find_predecessor(X2)
                if itr_pred_2 is not None:
                    update(e, itr_pred_2)

            sums.merge(res)

        sums.insert(current)
        res_map[v] = sums

    O(f"{ans}\n")


# Entry point: solve when run as a script, but not when imported.
if __name__ == "__main__":
    main()
0