結果

問題 No.3348 Tree Balance
コンテスト
ユーザー ZOI-dayo
提出日時 2025-10-25 16:32:39
言語 PyPy3
(7.3.15)
結果
MLE  
実行時間 -
コード長 5,265 bytes
コンパイル時間 349 ms
コンパイル使用メモリ 82,180 KB
実行使用メモリ 624,720 KB
最終ジャッジ日時 2025-11-13 20:54:27
合計ジャッジ時間 13,172 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other MLE * 1 -- * 24
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import heapq
from bisect import bisect_left

def solve():
    # 再帰深度を深く設定
    # N が大きい場合 (例: 2*10^5) に備える
    sys.setrecursionlimit(200010)

    try:
        N = int(sys.stdin.readline())
        W = list(map(int, sys.stdin.readline().split()))
        
        adj = [[] for _ in range(N)]
        for _ in range(N - 1):
            A, B = map(int, sys.stdin.readline().split())
            u, v = A - 1, B - 1
            adj[u].append(v)
            adj[v].append(u)

    except EOFError:
        return
    except Exception as e:
        return

    if N < 3:
        print(0)
        return

    total_sum = sum(W)
    subtree_sum = [0] * N
    
    # グローバル変数として最小差を保持
    # (リストやクラス変数にしてミュータブルに渡す)
    min_diff = [float('inf')]

    # --- 1. 事前計算: 部分木の重み総和 (DFS) ---
    def dfs_precompute(v, p):
        s = W[v]
        for u in adj[v]:
            if u == p:
                continue
            dfs_precompute(u, v)
            s += subtree_sum[u]
        subtree_sum[v] = s

    dfs_precompute(0, -1)

    # --- 最小差を更新するヘルパー関数 ---
    def check_min_diff(S1, S2, S3):
        if S1 <= 0 or S2 <= 0 or S3 <= 0:
             # 有効な3分割ではない (重みが0のケースなど)
             # ただし、問題の制約 (W_i >= 1) があれば不要
             return
        diff = max(S1, S2, S3) - min(S1, S2, S3)
        if diff < min_diff[0]:
            min_diff[0] = diff

    # --- 2. メインのDFS (マージテクニック) ---
    # v: 現在の頂点, p: 親
    # 戻り値: v の部分木内で可能な「切断点の重み」のソート済みリスト
    def dfs_solve(v, p):
        
        # my_desc_sums: v の子孫にあたる切断点の重み (S(w)) のソート済みリスト
        my_desc_sums = []

        for u in adj[v]:
            if u == p:
                continue
            
            # 1. 子からソート済みリストを受け取る
            child_desc_sums = dfs_solve(u, v)
            
            # --- 2. Case 3 (並列関係) の処理 ---
            # child_desc_sums (子uの内部) から1つ (S_A)
            # my_desc_sums (他の子) から1つ (S_B) を選ぶ
            
            # O(N log^2 N) にするため、小さい方からイテレート (Small-to-Large)
            list_A = child_desc_sums
            list_B = my_desc_sums
            if len(list_A) > len(list_B):
                list_A, list_B = list_B, list_A
            
            # O(min(|A|,|B|) * log(max(|A|,|B|)))
            # for S_A in list_A:
            #     target = (total_sum - S_A) / 2
            #     idx = bisect_left(list_B, target)
            #     for k in [idx, idx - 1]:
            #         if 0 <= k < len(list_B):
            #             S_B = list_B[k]
            #             S_C = total_sum - S_A - S_B
            #             check_min_diff(S_A, S_B, S_C)

            # O(N log N) のための 尺取り法 O(|A| + |B|)
            # list_A (S_A) は昇順, target = (T-S_A)/2 は降順
            # list_B (S_B) で target に近い点を探すポインタ ptr_B も降順に動く
            if list_A and list_B:
                ptr_b = len(list_B) - 1
                for S_A in list_A:
                    target = (total_sum - S_A) / 2
                    # target より大きい S_B をスキップ
                    while ptr_b > 0 and list_B[ptr_b] > target:
                        ptr_b -= 1
                    
                    # list_B[ptr_b] (<= target) と list_B[ptr_b+1] (> target) が候補
                    for k in [ptr_b, ptr_b + 1]:
                        if 0 <= k < len(list_B):
                            S_B = list_B[k]
                            S_C = total_sum - S_A - S_B
                            check_min_diff(S_A, S_B, S_C)
            
            # 3. リストのマージ (O(|A| + |B|))
            my_desc_sums = list(heapq.merge(my_desc_sums, child_desc_sums))

        # --- 4. Case 1 (祖先-子孫関係) の処理 ---
        # 1つ目の切断点を v (重み S_v)
        # 2つ目の切断点を my_desc_sums の中から (重み S_B)
        S_v = subtree_sum[v]
        S_rest = total_sum - S_v
        
        # S_B と S_v - S_B が近くなる S_B (S_B approx S_v / 2) を探す
        target = S_v / 2
        idx = bisect_left(my_desc_sums, target)
        for k in [idx, idx - 1]:
            if 0 <= k < len(my_desc_sums):
                S_B = my_desc_sums[k]
                S_A = S_v - S_B
                check_min_diff(S_A, S_B, S_rest)

        # --- 5. 親にリストを渡す ---
        # v自身も切断点になりうる (vが根でなければ)
        if p != -1:
            # S(v) をソートを保ったまま追加 (O(N_v))
            # O(N log N) にするためには、最後にマージする
            # (今回は heapq.merge を使うので O(N_v))
            my_desc_sums = list(heapq.merge(my_desc_sums, [S_v]))
            
        return my_desc_sums

    # 根 (頂点0) から探索開始
    dfs_solve(0, -1)

    print(min_diff[0])

if __name__ == "__main__":
    solve()
0