結果

問題 No.3348 Tree Balance
コンテスト
ユーザー ZOI-dayo
提出日時 2025-10-29 16:46:13
言語 C++23 (gcc 13.3.0 + boost 1.87.0)
結果
TLE  
実行時間 -
コード長 3,861 bytes
コンパイル時間 1,891 ms
コンパイル使用メモリ 151,428 KB
実行使用メモリ 38,096 KB
最終ジャッジ日時 2025-11-13 21:04:20
合計ジャッジ時間 14,505 ms
ジャッジサーバーID (参考情報) judge1 / judge4
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
other TLE * 1 -- * 24
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <algorithm>
#include <climits>
#include <cmath>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <set>
#include <utility>
#include <vector>

using namespace std;

using ll = long long;

// Shared state for one test case; solve() fills it and the DFS helpers read it.
ll total_sum{0};                    // sum of all vertex weights
ll min_diff{LLONG_MAX};             // best (max - min) spread found so far
std::vector<ll> W;                  // W[i]: weight of vertex i (0-based)
std::vector<std::vector<int>> adj;  // undirected adjacency lists of the tree
std::vector<ll> subtree_sum;        // subtree_sum[i]: weight sum of i's subtree

// Scores one candidate split into three parts with sums S1, S2 and S3.
// The split's spread (largest minus smallest) replaces min_diff when it is smaller.
// A candidate with any part whose sum is not positive is ignored.
// NOTE(review): that filter would also drop legitimate splits if zero or
// negative weights are allowed; confirm against the problem constraints.
void check_min_diff(ll S1, ll S2, ll S3) {
    const bool every_part_positive = S1 > 0 && S2 > 0 && S3 > 0;
    if (!every_part_positive) {
        return;
    }
    const auto [smallest, largest] = std::minmax({S1, S2, S3});
    min_diff = std::min(min_diff, largest - smallest);
}

// Fills subtree_sum[x] with the total weight W of x's subtree, for every node x
// in the subtree rooted at v. The tree is rooted so that p is v's parent
// (p == -1 when v is the root of the whole tree). Entries outside that
// subtree are left untouched.
//
// Iterative on purpose: a recursive walk uses one call frame per tree level
// and can overflow the call stack on long, path-shaped trees.
void dfs_precompute(int v, int p) {
    // (node, parent) pairs in preorder: every node precedes its descendants.
    std::vector<std::pair<int, int>> preorder;
    std::vector<std::pair<int, int>> pending;
    pending.emplace_back(v, p);
    while (!pending.empty()) {
        const auto [node, parent] = pending.back();  // copy before pop_back
        pending.pop_back();
        preorder.emplace_back(node, parent);
        for (int next : adj[node]) {
            if (next != parent) {
                pending.emplace_back(next, node);
            }
        }
    }

    // Walking the preorder backwards finishes every child before its parent.
    for (auto it = preorder.rbegin(); it != preorder.rend(); ++it) {
        const auto [node, parent] = *it;
        ll s = W[node];
        for (int child : adj[node]) {
            if (child != parent) {
                s += subtree_sum[child];
            }
        }
        subtree_sum[node] = s;
    }
}

vector<ll> dfs_solve(int v, int p) {
    
    vector<vector<ll>> child_lists;
    for (int u : adj[v]) {
        if (u == p) continue;
        child_lists.push_back(dfs_solve(u, v));
    }

    if (child_lists.empty()) {
        if (p != -1) {
            return {subtree_sum[v]};
        } else {
            return {};
        }
    }

    sort(child_lists.begin(), child_lists.end(), [](const vector<ll>& a, const vector<ll>& b) {
        return a.size() < b.size();
    });

    vector<ll> my_desc_sums = std::move(child_lists.back());
    child_lists.pop_back();

    for (const auto& list_A : child_lists) {
        const vector<ll>& list_B = my_desc_sums;
        
        if (!list_A.empty() && !list_B.empty()) {
            int ptr_b = list_B.size() - 1;
            for (ll S_A : list_A) {
                // (total_sum - S_A) / 2.0 に最も近いS_Bを探す
                double target = (double)(total_sum - S_A) / 2.0;
                
                while (ptr_b > 0 && list_B[ptr_b] > target) {
                    ptr_b--;
                }
                
                for (int k = ptr_b; k <= ptr_b + 1; ++k) {
                    if (k >= 0 && k < list_B.size()) {
                        ll S_B = list_B[k];
                        ll S_C = total_sum - S_A - S_B;
                        check_min_diff(S_A, S_B, S_C);
                    }
                }
            }
        }

        vector<ll> merged_list;
        merged_list.reserve(my_desc_sums.size() + list_A.size());
        std::merge(my_desc_sums.begin(), my_desc_sums.end(),
                   list_A.begin(), list_A.end(),
                   std::back_inserter(merged_list));
        my_desc_sums = std::move(merged_list);
    }

    ll S_v = subtree_sum[v];
    ll S_rest = total_sum - S_v;
    
    double target = (double)S_v / 2.0;
    
    auto it = std::lower_bound(my_desc_sums.begin(), my_desc_sums.end(), target);
    
    if (it != my_desc_sums.end()) {
        ll S_B = *it;
        ll S_A = S_v - S_B;
        check_min_diff(S_A, S_B, S_rest);
    }
    if (it != my_desc_sums.begin()) {
        ll S_B = *std::prev(it);
        ll S_A = S_v - S_B;
        check_min_diff(S_A, S_B, S_rest);
    }


    if (p != -1) {
        my_desc_sums.push_back(S_v);
        std::sort(my_desc_sums.begin(), my_desc_sums.end());
    }
    
    return my_desc_sums;
}

void solve() {
    int N;
    if (!(cin >> N)) {
        return; // EOF
    }
    
    W.resize(N);
    total_sum = 0;
    for (int i = 0; i < N; ++i) {
        cin >> W[i];
        total_sum += W[i];
    }
    
    adj.assign(N, vector<int>());
    for (int i = 0; i < N - 1; ++i) {
        int A, B;
        cin >> A >> B;
        int u = A - 1, v = B - 1;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    if (N < 3) {
        cout << 0 << endl;
        return;
    }

    subtree_sum.assign(N, 0);
    min_diff = LLONG_MAX;

    dfs_precompute(0, -1);
    
    dfs_solve(0, -1);

    cout << min_diff << endl;
}

// Entry point: turns off C stdio synchronisation and unties cin from cout for
// faster input, then handles the single test case.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    solve();
}
0