結果

問題 No.3348 Tree Balance
コンテスト
ユーザー ZOI-dayo
提出日時 2025-10-25 16:36:43
言語 Python3
(3.13.1 + numpy 2.2.1 + scipy 1.14.1)
結果
TLE  
実行時間 -
コード長 4,330 bytes
コンパイル時間 427 ms
コンパイル使用メモリ 12,288 KB
実行使用メモリ 97,932 KB
最終ジャッジ日時 2025-11-13 20:55:25
合計ジャッジ時間 13,005 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other TLE * 1 -- * 24
権限があれば一括ダウンロードができます

ソースコード

diff #

# By Gemini

import sys
import heapq
from bisect import bisect_left

# 再帰深度を深く設定
sys.setrecursionlimit(200010)

def solve():
    try:
        N = int(sys.stdin.readline())
        W = list(map(int, sys.stdin.readline().split()))
        
        adj = [[] for _ in range(N)]
        for _ in range(N - 1):
            A, B = map(int, sys.stdin.readline().split())
            u, v = A - 1, B - 1
            adj[u].append(v)
            adj[v].append(u)

    except EOFError:
        return
    except Exception as e:
        return

    if N < 3:
        print(0)
        return

    total_sum = sum(W)
    subtree_sum = [0] * N
    
    # 最小差 (ミュータブルなリストとして渡す)
    min_diff = [float('inf')]

    # --- 1. 事前計算: 部分木の重み総和 (O(N)) ---
    def dfs_precompute(v, p):
        s = W[v]
        for u in adj[v]:
            if u == p:
                continue
            dfs_precompute(u, v)
            s += subtree_sum[u]
        subtree_sum[v] = s
    
    dfs_precompute(0, -1)

    # --- 最小差を更新するヘルパー関数 ---
    def check_min_diff(S1, S2, S3):
        # W_i >= 1 が保証されていれば S_i > 0 は不要
        if S1 <= 0 or S2 <= 0 or S3 <= 0:
             return
        diff = max(S1, S2, S3) - min(S1, S2, S3)
        if diff < min_diff[0]:
            min_diff[0] = diff

    # --- 2. メインのDFS (O(N log^2 N)) ---
    # 戻り値: v の部分木内で可能な「切断点の重み」のソート済みリスト
    def dfs_solve(v, p):
        
        # --- 1. 子からリストを収集 ---
        child_lists = []
        for u in adj[v]:
            if u == p:
                continue
            child_lists.append(dfs_solve(u, v))
        
        # --- 2. Small-to-Large (マージテク) の準備 ---
        if not child_lists:
            # v が葉の場合
            if p != -1: # vが根(N=1)でなければ
                # 自身の重みをリストに入れて返す
                return [subtree_sum[v]]
            else:
                return []
        
        # リストをサイズ順にソート
        child_lists.sort(key=len)
        
        # 最大のリストを my_desc_sums (ベース) にする
        my_desc_sums = child_lists.pop()
        
        # --- 3. 小さいリストを大きいリストにマージ (Case 3 処理) ---
        for list_A in child_lists: # list_A は list_B (my_desc_sums) より小さい
            list_B = my_desc_sums
            
            # Case 3 (並列関係) の処理 (尺取り法: O(|A| + |B|))
            if list_A and list_B:
                ptr_b = len(list_B) - 1
                for S_A in list_A:
                    target = (total_sum - S_A) / 2
                    while ptr_b > 0 and list_B[ptr_b] > target:
                        ptr_b -= 1
                    
                    for k in [ptr_b, ptr_b + 1]:
                        if 0 <= k < len(list_B):
                            S_B = list_B[k]
                            S_C = total_sum - S_A - S_B
                            check_min_diff(S_A, S_B, S_C)
            
            # リストのマージ (O(|A| + |B|))
            my_desc_sums = list(heapq.merge(my_desc_sums, list_A))
        
        # --- 4. Case 1 (祖先-子孫関係) の処理 ---
        S_v = subtree_sum[v]
        S_rest = total_sum - S_v
        
        # S_B と S_v - S_B が近くなる S_B (S_B approx S_v / 2) を探す
        # (O(log N_v))
        target = S_v / 2
        idx = bisect_left(my_desc_sums, target)
        for k in [idx, idx - 1]:
            if 0 <= k < len(my_desc_sums):
                S_B = my_desc_sums[k]
                S_A = S_v - S_B
                check_min_diff(S_A, S_B, S_rest)

        # --- 5. 親にリストを渡す (O(N_v log N_v) のボトルネック) ---
        if p != -1:
            # O(N^2) になる O(N_v) の挿入を避ける
            # my_desc_sums = list(heapq.merge(my_desc_sums, [S_v]))
            
            # O(N log^2 N) になる O(N_v log N_v) の挿入
            my_desc_sums.append(S_v)
            my_desc_sums.sort()
            
        return my_desc_sums

    # 根 (頂点0) から探索開始
    dfs_solve(0, -1)

    print(min_diff[0])

if __name__ == "__main__":
    solve()
0