結果

問題 No.3348 Tree Balance
コンテスト
ユーザー ZOI-dayo
提出日時 2025-10-29 18:49:41
言語 Python3 (3.13.1 + numpy 2.2.1 + scipy 1.14.1)
結果 TLE
実行時間 -
コード長 5,399 bytes
コンパイル時間 317 ms
コンパイル使用メモリ 12,928 KB
実行使用メモリ 145,464 KB
最終ジャッジ日時 2025-11-13 21:07:23
合計ジャッジ時間 13,184 ms
ジャッジサーバーID (参考情報) judge1 / judge2
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
other TLE * 1 -- * 24
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

# Headroom of 2 * 10**5 + 10 frames for any deeply recursive helper in this file.
sys.setrecursionlimit(2 * 10**5 + 10)
class Node:
    """One AVL tree node: a stored value, two children, and its subtree height."""

    # Many nodes are allocated during small-to-large merging; slots drop the
    # per-instance __dict__ (less memory, faster attribute access).
    __slots__ = ("value", "left", "right", "height")

    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None
        self.height = 1  # a new node is a leaf


class AVLTree:
    """Height-balanced binary search tree storing a set of distinct values.

    Inserting a value that is already present is a no-op, so ``len(tree)``
    is the number of distinct stored values.
    """

    def __init__(self):
        self.root = None
        self.size = 0

    def __len__(self):
        return self.size

    def _get_height(self, node):
        """Return the height of *node*'s subtree (0 for an empty subtree)."""
        if not node:
            return 0
        return node.height

    def _get_balance(self, node):
        """Return left height minus right height of *node* (0 if empty)."""
        if not node:
            return 0
        return self._get_height(node.left) - self._get_height(node.right)

    def _right_rotate(self, y):
        """Rotate the subtree at *y* to the right; return the new root (y.left)."""
        x = y.left
        T2 = x.right

        x.right = y
        y.left = T2

        # y is now below x, so refresh y first.
        y.height = 1 + max(self._get_height(y.left), self._get_height(y.right))
        x.height = 1 + max(self._get_height(x.left), self._get_height(x.right))

        return x

    def _left_rotate(self, x):
        """Rotate the subtree at *x* to the left; return the new root (x.right)."""
        y = x.right
        T2 = y.left

        y.left = x
        x.right = T2

        # x is now below y, so refresh x first.
        x.height = 1 + max(self._get_height(x.left), self._get_height(x.right))
        y.height = 1 + max(self._get_height(y.left), self._get_height(y.right))

        return y

    def _rebalance(self, root, value):
        """Refresh *root*'s height after *value* was inserted below it.

        Rotate if the node became unbalanced. Comparing *value* with the
        heavy child's value selects the single rotation (outer insertion) or
        the double rotation (inner insertion). Returns the possibly new root
        of this subtree.
        """
        left = root.left
        right = root.right
        lh = left.height if left else 0
        rh = right.height if right else 0
        root.height = 1 + (lh if lh > rh else rh)
        balance = lh - rh

        if balance > 1:
            if value < left.value:  # left-left
                return self._right_rotate(root)
            root.left = self._left_rotate(left)  # left-right
            return self._right_rotate(root)
        if balance < -1:
            if value > right.value:  # right-right
                return self._left_rotate(root)
            root.right = self._right_rotate(right)  # right-left
            return self._left_rotate(root)
        return root

    def insert(self, value):
        """Add *value* to the set (no-op if already present), keeping the tree balanced.

        Iterative: walk down recording the path, link a new leaf, then
        retrace upward. Retracing stops at the first subtree whose height
        did not change, because nothing above it can have changed either.
        The resulting tree is the same as a full recursive retrace produces.
        """
        node = self.root
        if node is None:
            self.root = Node(value)
            self.size += 1
            return

        path = []  # (ancestor, went_left) pairs, from the root downward
        while True:
            if value < node.value:
                went_left = True
                child = node.left
            elif value > node.value:
                went_left = False
                child = node.right
            else:
                return  # already stored
            path.append((node, went_left))
            if child is None:
                break
            node = child

        leaf = Node(value)
        self.size += 1
        if went_left:
            node.left = leaf
        else:
            node.right = leaf

        for depth in range(len(path) - 1, -1, -1):
            node = path[depth][0]
            old_height = node.height
            subtree = self._rebalance(node, value)
            if subtree is not node:
                # A rotation replaced this subtree's root; relink it.
                if depth == 0:
                    self.root = subtree
                else:
                    parent, parent_went_left = path[depth - 1]
                    if parent_went_left:
                        parent.left = subtree
                    else:
                        parent.right = subtree
            if subtree.height == old_height:
                break

    def _insert(self, root, value):
        """Recursively insert *value* under *root*; return the new subtree root.

        Recursive equivalent of ``insert``, retained for compatibility.
        """
        if not root:
            self.size += 1
            return Node(value)

        if value < root.value:
            root.left = self._insert(root.left, value)
        elif value > root.value:
            root.right = self._insert(root.right, value)
        else:
            return root

        return self._rebalance(root, value)

    def find_lower_bound(self, value):
        """Return the smallest stored value >= *value*, or None if there is none."""
        node = self.root
        best_so_far = None
        while node:
            if node.value == value:
                return value
            elif node.value > value:
                best_so_far = node.value  # candidate; look for a smaller one
                node = node.left
            else:
                node = node.right
        return best_so_far

    def find_predecessor(self, value):
        """Return the largest stored value < *value*, or None if there is none."""
        node = self.root
        best_so_far = None
        while node:
            if node.value < value:
                best_so_far = node.value  # keep as candidate
                node = node.right  # look for a larger value (still < value)
            else:  # node.value >= value
                node = node.left  # need a smaller value
        return best_so_far

    def __iter__(self):
        """Yield the stored values in ascending order (iterative in-order walk)."""
        stack = []
        curr = self.root
        while stack or curr:
            while curr:
                stack.append(curr)
                curr = curr.left
            curr = stack.pop()
            yield curr.value
            curr = curr.right

    def merge(self, other_tree):
        """Insert every value of *other_tree* into this tree.

        *other_tree* is only read, not modified.
        """
        for val in other_tree:
            self.insert(val)

def main():
    """Solve one test case read from stdin and print the answer.

    Input: ``n``, then ``n`` vertex weights, then ``n - 1`` edges given as
    1-indexed endpoint pairs. Cutting two edges leaves three parts; the
    printed value is the smallest (largest part sum - smallest part sum)
    over the candidate cut pairs examined below.

    Both tree traversals are iterative. A path-shaped tree would otherwise
    recurse once per vertex, which is slow in CPython and limited by the
    recursion limit.
    """
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    w = [int(tok) for tok in data[1:1 + n]]
    graph = [[] for _ in range(n)]
    ends = map(int, data[1 + n:1 + n + 2 * (n - 1)])
    for a, b in zip(ends, ends):  # same iterator twice -> consecutive pairs
        graph[a - 1].append(b - 1)
        graph[b - 1].append(a - 1)

    # Preorder from vertex 0. A vertex is appended before any of its
    # descendants, so reversed(order) visits children before their parent.
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
        v = stack.pop()
        order.append(v)
        pv = parent[v]
        for c in graph[v]:
            if c != pv:
                parent[c] = v
                stack.append(c)

    # sum_[v]: total weight of the subtree rooted at v.
    sum_ = list(w)
    for v in reversed(order):
        pv = parent[v]
        if pv != -1:
            sum_[pv] += sum_[v]
    total = sum_[0]

    INF = 1 << 60

    def best_spread(fixed, tree):
        """Return the best spread among splits with parts `fixed` and a value of `tree`.

        The third part is ``total - fixed - s``. For a given `fixed`, the
        spread never decreases as ``s`` moves away from
        ``(total - fixed) / 2``. So only the nearest stored value at/above
        and below that midpoint need checking. Returns INF if no candidate
        exists.
        """
        target = (total - fixed) // 2
        best = INF
        for s in (tree.find_lower_bound(target), tree.find_predecessor(target)):
            if s is not None:
                rest = total - fixed - s
                spread = max(fixed, s, rest) - min(fixed, s, rest)
                if spread < best:
                    best = spread
        return best

    ans = INF
    # subtree_sets[v]: AVLTree of all subtree sums within v's subtree; it is
    # released as soon as v's parent has absorbed it.
    subtree_sets = [None] * n
    for i in reversed(order):
        p = parent[i]
        current = sum_[i]
        # Part above the edge (p, i).
        # NOTE(review): at the root this is 0 (there is no edge above), so the
        # nested check scores a two-part split plus an empty part; presumably
        # never better than a real split when weights are positive — confirm
        # against the problem constraints.
        remain = total - current
        sums = AVLTree()  # subtree sums under the children handled so far
        for c in graph[i]:
            if c == p:
                continue
            res = subtree_sets[c]
            subtree_sets[c] = None

            # Nested cuts: edge (p, i) plus one edge inside c's subtree.
            ans = min(ans, best_spread(remain, res))

            # Disjoint cuts: one edge under an earlier child, one under c.
            # Iterate the smaller set and query the larger (small-to-large).
            if len(sums) < len(res):
                sums, res = res, sums
            for e in res:
                ans = min(ans, best_spread(e, sums))
            sums.merge(res)

        sums.insert(current)
        subtree_sets[i] = sums

    sys.stdout.write(f"{ans}\n")

if __name__ == "__main__":
    # Solve only when executed as a script; importing this module just
    # defines the classes and main() without reading stdin.
    main()
0