結果

問題 No.3194 Do Optimize Your Solution
ユーザー noya2
提出日時 2025-06-25 04:11:45
言語 C++23 (gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 1,122 ms / 3,000 ms
コード長 9,286 bytes
コンパイル時間 3,558 ms
コンパイル使用メモリ 294,544 KB
実行使用メモリ 74,576 KB
最終ジャッジ日時 2025-06-27 20:53:52
合計ジャッジ時間 14,649 ms
ジャッジサーバーID (参考情報): judge5 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 17
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<bits/stdc++.h>
// rep(i, s, n): loop with an int index i over the half-open range [s, n).
#define rep(i,s,n) for (int i = (int)(s); i < (int)(n); i++)
using namespace std;
using ull = unsigned long long;
// Number of slots in each vertex's distance table (`dist` in main), which is
// indexed by centroid-decomposition level. Each level at least halves the
// component size, so all levels stay below mx as long as n < 2^19 — TODO
// confirm this covers the problem's upper limit on n.
const int mx = 19;

// Reads the next non-negative decimal integer from stdin.
// Any non-digit characters before the number (spaces, '\n', '\r', ...) are
// skipped; the single character that terminates the number is consumed.
// Returns 0 if EOF is reached before any digit is seen.
// The character is held in an int so EOF is never mistaken for a digit
// (the previous char-based XOR trick looped forever on a number ending at EOF).
int input(){
    int c = getchar_unlocked();
    while (c != EOF && (c < '0' || c > '9')){
        c = getchar_unlocked();
    }
    int x = 0;
    while (c >= '0' && c <= '9'){
        x = x * 10 + (c - '0');
        c = getchar_unlocked();
    }
    return x;
}

// Simple stopwatch: call start(), then stop(), then duration() to get the
// elapsed time between the two calls in whole microseconds.
// Uses steady_clock (monotonic). high_resolution_clock is an alias of the
// adjustable system_clock in libstdc++ and is not suitable for intervals.
struct Timer {
    using clock_type = std::chrono::steady_clock;
    clock_type::time_point begin, end;
    // Records the starting instant.
    void start(){
        begin = clock_type::now();
    }
    // Records the ending instant.
    void stop(){
        end = clock_type::now();
    }
    // Elapsed time from the last start() to the last stop(), in microseconds.
    unsigned long long duration() const {
        return std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count();
    }
};

// Compressed-sparse-row bucket container (used here for tree adjacency lists).
// Usage: construct with the bucket count n, call add(idx, elem) for every
// entry, then call build() exactly once. Afterwards csr[idx] is a view over
// the elements added to bucket idx, in insertion order.
// NOTE: add() must not be called after build(), and build() must not be called
// twice — build() reuses `start` to hold the bucket offsets.
template<class E>
struct csr {
    csr () {}
    csr (int _n) : n(_n) {}
    // m is only a capacity hint for the number of add() calls.
    csr (int _n, int m) : n(_n){
        start.reserve(m);
        elist.reserve(m);
    }
    // Queues elem for bucket idx (expects 0 <= idx < n).
    void add(int idx, E elem){
        start.emplace_back(idx);
        elist.emplace_back(elem);
    }
    // Stable counting sort of the queued entries by bucket index.
    void build(){
        int m = static_cast<int>(start.size());
        std::vector<E> nelist(m);
        // First nstart[idx + 2] counts bucket idx; after the prefix sum,
        // nstart[idx + 1] is the first output slot of bucket idx.
        std::vector<int> nstart(n + 2, 0);
        for (int i = 0; i < m; i++){
            nstart[start[i] + 2]++;
        }
        for (int i = 1; i < n; i++){
            nstart[i + 2] += nstart[i + 1];
        }
        // Each placement advances nstart[idx + 1]. When done it equals the end
        // of bucket idx, so nstart[k] is the start offset of bucket k
        // (and nstart[n] == m).
        for (int i = 0; i < m; i++){
            nelist[nstart[start[i] + 1]++] = elist[i];
        }
        elist.swap(nelist);
        start.swap(nstart);
    }
    // Read-only view over bucket idx (valid only after build()).
    auto operator[](int idx) const {
        return std::ranges::subrange(elist.begin()+start[idx],elist.begin()+start[idx+1]);
    }
    // Mutable view over bucket idx (valid only after build()).
    auto operator[](int idx){
        return std::ranges::subrange(elist.begin()+start[idx],elist.begin()+start[idx+1]);
    }
    // Number of buckets.
    size_t size() const {
        return n;
    }
    int n = 0;                  // bucket count; default-initialised so csr() is well-defined
    std::vector<int> start;     // before build(): bucket of each entry; after: bucket offsets
    std::vector<E> elist;       // entries (grouped by bucket after build())
};

// Reads n, then the n-1 edges of tree A, then the n-1 edges of tree B
// (1-indexed vertex pairs).
// Prints the sum over ordered vertex pairs (u, v) of dist_A(u, v) * dist_B(u, v),
// accumulated in unsigned 64-bit arithmetic. This is computed as
// 2 * sum over non-root vertices v of B of
//   sum_{x in subtree_B(v), y outside} dist_A(x, y).
// Each unordered pair is counted once per B-edge on its B-path, i.e. dist_B times.
int main(){
    cin.tie(0)->sync_with_stdio(0);
    int n = input();

    // Adjacency lists of both trees, 0-indexed.
    csr<int> a(n,2*(n-1)), b(n,2*(n-1));
    rep(i,0,n-1){
        int u = input(); u--;
        int v = input(); v--;
        a.add(u,v);
        a.add(v,u);
    }
    rep(i,0,n-1){
        int u = input(); u--;
        int v = input(); v--;
        b.add(u,v);
        b.add(v,u);
    }
    a.build();
    b.build();

    // ---- Centroid decomposition of tree A ----
    // comp(c) denotes the component in which c was chosen as centroid.
    vector<bool> done(n,false);      // vertex already chosen as a centroid
    vector<int> par(n);              // parent centroid in the decomposition tree
    vector<array<int,mx>> dist(n);   // dist[v][k] (k >= 1): dist_A(v, centroid at level k-1 containing v)
    struct dat {
        int cnt;
        ull sum;
    };
    // After the decomposition, with sign(w) = +1 for "off" and -1 for "on"
    // vertices (all start off):
    //   dp[c] = { sum_{w in comp(c)} sign(w), sum_{w in comp(c)} sign(w) * dist_A(c, w) }
    //   ep[c] = { sum_{w in comp(c)} sign(w), sum_{w in comp(c)} sign(w) * dist_A(par[c], w) }
    // During the decomposition, dp[].cnt is reused as scratch subtree size.
    vector<dat> dp(n), ep(n);
    vector<int> dep(n);              // decomposition level of each centroid

    // Decomposes the component containing `one` (size sz) at level recdep.
    // Returns {its centroid, sum of dist[v][recdep] over the component}.
    // For recdep >= 1, `one` is adjacent to the parent centroid and starts at
    // d = 1, so dist[v][recdep] is the distance to that parent centroid.
    // At recdep == 0 the recorded distances are never read.
    auto cd = [&](auto _cd, int one, int sz, int recdep) -> pair<int,ull> {
        int ctr = -1, ctrpar = -1;
        ull dsum = 0;
        // DFS: records distances and subtree sizes, and picks as centroid the
        // first vertex in post-order whose subtree holds at least half of sz.
        auto subsz = [&](auto _subsz, int v, int f, int d) -> void {
            dp[v].cnt = 1;
            dist[v][recdep] = d;
            dsum += d;
            for (int u : a[v]){
                if (done[u]) continue;
                if (u == f) continue;
                _subsz(_subsz,u,v,d+1);
                dp[v].cnt += dp[u].cnt;
            }
            if (dp[v].cnt*2 >= sz){
                if (ctr == -1){
                    ctr = v;
                    ctrpar = f;
                }
            }
        };
        subsz(subsz,one,-1,1);
        dep[ctr] = recdep;
        done[ctr] = true;
        ull ssum = 0;
        for (int none : a[ctr]){
            if (done[none]) continue;
            // The side through the DFS parent has everything outside ctr's subtree.
            int c0 = (none == ctrpar ? sz - dp[ctr].cnt : dp[none].cnt);
            auto [ch, s0] = _cd(_cd,none,c0,recdep+1);
            par[ch] = ctr;
            ep[ch] = {c0,s0};
            ssum += s0;
        }
        dp[ctr] = {sz, ssum};
        return {ctr, dsum};
    };
    cd(cd,0,n,0);

    // sum01 = sum_{x on, y off} dist_A(x, y), maintained under single-vertex flips.
    ull sum01 = 0;

    // Switches v from off to on. Adds sum_w sign(w) * dist_A(v, w) to sum01,
    // split over v's centroid ancestors, then flips v's sign in every dp/ep
    // entry that covers it.
    auto turn_on = [&](int v){
        sum01 += dp[v].sum;
        dp[v].cnt -= 2;
        int recdep = dep[v];
        int c = v;
        while (recdep > 0){
            int p = par[c];
            recdep--;
            int d = dist[v][recdep+1];
            // Vertices in comp(p) but not in comp(c): distance = dist(p, w) + d.
            sum01 += dp[p].sum - ep[c].sum;
            sum01 += ull(dp[p].cnt - ep[c].cnt) * d;
            ep[c].cnt -= 2;
            ep[c].sum -= 2*d;
            dp[p].cnt -= 2;
            dp[p].sum -= 2*d;
            c = p;
        }
    };
    // Exact inverse of turn_on: switches v from on back to off.
    auto turn_off = [&](int v){
        sum01 -= dp[v].sum;
        dp[v].cnt += 2;
        int recdep = dep[v];
        int c = v;
        while (recdep > 0){
            int p = par[c];
            recdep--;
            int d = dist[v][recdep+1];
            sum01 -= dp[p].sum - ep[c].sum;
            sum01 -= ull(dp[p].cnt - ep[c].cnt) * d;
            ep[c].cnt += 2;
            ep[c].sum += 2*d;
            dp[p].cnt += 2;
            dp[p].sum += 2*d;
            c = p;
        }
    };

    // ---- DSU on tree (small-to-large) over tree B rooted at 0 ----
    ull ans = 0;
    vector<int> sub(n), down(n), tour(n);   // subtree size, pre-order index, pre-order sequence
    int t = 0;
    auto subsz = [&](auto _subsz, int v, int f) -> void {
        sub[v] = 1;
        tour[t] = v;
        down[v] = t++;
        for (int u : b[v]){
            if (u == f) continue;
            _subsz(_subsz,u,v);
            sub[v] += sub[u];
        }
    };
    subsz(subsz,0,-1);
    // On entry every vertex is off. On exit, subtree_B(v) is on if !top,
    // otherwise everything is off again. For each non-root v, adds the cut
    // value sum01 with exactly subtree_B(v) switched on.
    auto dfs = [&](auto _dfs, int v, int f, bool top) -> void {
        int heavy = -1, ch = -1;
        for (int u : b[v]){
            if (u == f) continue;
            if (heavy < sub[u]){
                heavy = sub[u];
                ch = u;
            }
        }
        for (int u : b[v]){
            if (u == f) continue;
            if (u == ch) continue;
            _dfs(_dfs,u,v,true);
        }
        if (ch != -1){
            _dfs(_dfs,ch,v,false);
        }
        if (v == 0) return ;
        for (int u : b[v]){
            if (u == f) continue;
            if (u == ch) continue;
            for (int i = down[u]; i < down[u]+sub[u]; i++){
                turn_on(tour[i]);
            }
        }
        turn_on(v);
        ans += sum01;
        if (top){
            for (int i = down[v]; i < down[v]+sub[v]; i++){
                turn_off(tour[i]);
            }
        }
    };
    dfs(dfs,0,-1,true);

    // Each unordered pair was counted once; the answer is over ordered pairs.
    ans *= 2;
    cout << ans << '\n';
    return 0;
}
0