結果

問題 No.3194 Do Optimize Your Solution
ユーザー noya2
提出日時 2025-06-25 00:13:21
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 1,841 ms / 3,000 ms
コード長 5,353 bytes
コンパイル時間 3,458 ms
コンパイル使用メモリ 294,164 KB
実行使用メモリ 82,316 KB
最終ジャッジ日時 2025-06-27 20:53:24
合計ジャッジ時間 21,853 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 17
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<bits/stdc++.h>
// rep(i, s, n): loop an int counter i over the half-open range [s, n).
#define rep(i, s, n) for (int i = static_cast<int>(s); i < static_cast<int>(n); ++i)
using namespace std;
typedef unsigned long long ull;
// Number of per-vertex slots in the centroid-distance table used by main()
// (one slot per centroid-tree depth).
constexpr int mx = 19;

// Reads the next non-negative decimal integer from stdin.
// Any non-digit characters before the number (spaces, '\r', '\n', tabs, ...) are
// skipped. The first non-digit after the number is consumed. Returns 0 if EOF is
// reached before any digit.
// The character is kept in an int so EOF can never be mistaken for a digit. With a
// (signed) char, (char)(EOF ^ 48) < 10 held, and reading the last number without
// a trailing newline never terminated.
int input(){
    int c = getchar_unlocked();
    // Skip separators; do not assume exactly one separator between numbers.
    while (c != EOF && (c < '0' || c > '9')){
        c = getchar_unlocked();
    }
    int x = 0;
    while ('0' <= c && c <= '9'){
        x = x * 10 + (c - '0');
        c = getchar_unlocked();
    }
    return x;
}

// Reads two trees A and B on the same n vertices and prints
//   2 * sum over unordered pairs {x, y} of dist_A(x, y) * dist_B(x, y)
// computed with unsigned 64-bit (wrapping) arithmetic.
//
// Method:
//   - A centroid decomposition of A maintains, under single-vertex on/off toggles,
//     sum01 = sum of dist_A(u, w) over off u and on w.
//   - A DSU-on-tree sweep over B makes the on-set equal to each B-subtree in turn.
//     For every non-root v it adds sum01, i.e. once per B-edge cut.
// NOTE(review): the recursive lambdas below recurse to depth up to n on path-shaped
// trees; this relies on a large stack.
int main(){
    // Only affects iostreams; input is read with getchar_unlocked, output via cout.
    cin.tie(0)->sync_with_stdio(0);
    // Input: n, then n-1 edges of tree A, then n-1 edges of tree B (1-indexed).
    int n = input();
    vector<vector<int>> a(n), b(n);
    rep(i,0,n-1){
        int u = input(); u--;
        int v = input(); v--;
        a[u].emplace_back(v);
        a[v].emplace_back(u);
    }
    rep(i,0,n-1){
        int u = input(); u--;
        int v = input(); v--;
        b[u].emplace_back(v);
        b[v].emplace_back(u);
    }
    // Centroid decomposition of tree A:
    //   done[x]    : x was already chosen as a centroid (removed from the forest)
    //   par[x]     : parent of centroid x in the centroid tree (-1 for the root)
    //   dep[x]     : depth of centroid x in the centroid tree (root = 0)
    //   dist[v][k] : dist_A(v, c), where c is v's ancestor centroid at depth k
    //                (filled for every k < dep[v])
    //   dp[x]      : over x's component, with distances measured to x:
    //                {#off, sum dist of off, #on, sum dist of on}.
    //                The counts include x itself, at distance 0.
    //   ep[c]      : the same four values restricted to c's component but measured
    //                to par[c]. Used to exclude pairs lying in one child component.
    // NOTE(review): dist has mx slots per vertex, so the centroid-tree depth must stay
    // below mx (each level at least halves the component, so this holds for
    // n < 2^19). Confirm against the problem's limit on n.
    vector<bool> done(n,false);
    vector<int> par(n);
    vector<array<int,mx>> dist(n);
    vector<array<ull,4>> dp(n), ep(n);
    vector<int> dep(n);
    // cd(one, sz, recdep): decomposes the component containing `one` (sz vertices).
    // Its centroid sits at centroid-tree depth recdep; the centroid is returned.
    auto cd = [&](auto _cd, int one, int sz, int recdep) -> int {
        int ctr = -1;
        // Computes subtree sizes rooted at `one`. The first vertex in post-order whose
        // subtree holds at least sz/2 vertices is a centroid:
        //   - each child finished earlier without qualifying, so is smaller than sz/2;
        //   - the remaining part has at most sz/2 vertices.
        auto subsz = [&](auto _subsz, int v, int f) -> int {
            int ret = 1;
            for (int u : a[v]){
                if (done[u]) continue;
                if (u == f) continue;
                ret += _subsz(_subsz,u,v);
            }
            if (ret*2 >= sz){
                if (ctr == -1){
                    ctr = v;
                }
            }
            return ret;
        };
        subsz(subsz,one,-1);
        assert(ctr != -1);
        dep[ctr] = recdep;
        done[ctr] = true;
        // csum starts at 1 to count ctr itself (distance 0 to itself).
        ull csum = 1, ssum = 0;
        // Every not-yet-removed neighbour of ctr roots one child component.
        for (int none : a[ctr]){
            if (done[none]) continue;
            ull c0 = 0, s0 = 0;
            // Counts the child component and sums the distances to ctr.
            // Also records dist[v][recdep] for every vertex in it.
            auto dfs = [&](auto _dfs, int v, int f, int d) -> void {
                c0 += 1;
                s0 += d;
                dist[v][recdep] = d;
                for (int u : a[v]){
                    if (done[u]) continue;
                    if (u == f) continue;
                    _dfs(_dfs,u,v,d+1);
                }
            };
            dfs(dfs,none,ctr,1);
            int ch = _cd(_cd,none,c0,recdep+1);
            par[ch] = ctr;
            // Child component's totals w.r.t. ctr. Everything starts "off".
            ep[ch] = {c0,s0,0,0};
            csum += c0;
            ssum += s0;
        }
        // All vertices start "off": counts and sums go into slots 0 and 1.
        dp[ctr] = {csum, ssum, 0, 0};
        return ctr;
    };
    int root = cd(cd,0,n,0);
    par[root] = -1;
    vector<bool> on(n,false);
    // sum01 = sum over (u off, w on) of dist_A(u, w).
    // Each such pair is charged to the deepest centroid x whose component contains
    // both u and w. There u and w lie in different child components of x (or one of
    // them is x), so dist_A(u, w) = dist_x(u) + dist_x(w). Hence
    //   sum01 = sum over x of [ cnt_on(x) * sum_off(x) + cnt_off(x) * sum_on(x) ],
    // minus the same expression for pairs inside a single child component c of x
    // (taken from ep[c]).
    // All arithmetic is unsigned and wraps mod 2^64; temporarily "negative" terms
    // cancel out.
    ull sum01 = 0;
    // turn_on(v): flips v from off to on. Updates sum01 and every dp/ep entry on v's
    // centroid-tree path, in O(dep[v]) steps.
    auto turn_on = [&](int v){
        assert(!on[v]);
        on[v] = true;
        // At v's own centroid: drop pairs (v off, w on) and add pairs (u off, v on).
        // v is at distance 0 from itself, so slots 1 and 3 of dp[v] do not change.
        sum01 -= dp[v][3];
        dp[v][0]--;
        dp[v][2]++;
        sum01 += dp[v][1];
        int recdep = dep[v];
        int c = v;
        // Walks up the centroid tree. c is the centroid whose component contains v
        // and whose parent p is processed next.
        while (recdep > 0){
            int p = par[c];
            recdep--;
            // d = dist_A(v, p)
            int d = dist[v][recdep];
            // Remove pairs (v off, w on) with w in p's component but outside c's.
            sum01 -= (dp[p][3] - ep[c][3]);
            sum01 -= (dp[p][2] - ep[c][2]) * d;
            // Move v from the "off" tallies to the "on" tallies at this level.
            ep[c][0]--;
            ep[c][1] -= d;
            ep[c][2]++;
            ep[c][3] += d;
            dp[p][0]--;
            dp[p][1] -= d;
            dp[p][2]++;
            dp[p][3] += d;
            // Add pairs (u off, v on) with u in p's component but outside c's.
            sum01 += (dp[p][1] - ep[c][1]);
            sum01 += (dp[p][0] - ep[c][0]) * d;
            c = p;
        }
    };
    // turn_off(v): exact mirror of turn_on. Flips v from on back to off.
    auto turn_off = [&](int v){
        assert(on[v]);
        on[v] = false;
        // At v's own centroid: drop pairs (u off, v on) and add pairs (v off, w on).
        sum01 -= dp[v][1];
        dp[v][2]--;
        dp[v][0]++;
        sum01 += dp[v][3];
        int recdep = dep[v];
        int c = v;
        while (recdep > 0){
            int p = par[c];
            recdep--;
            // d = dist_A(v, p)
            int d = dist[v][recdep];
            // Remove pairs (u off, v on) with u in p's component but outside c's.
            sum01 -= (dp[p][1] - ep[c][1]);
            sum01 -= (dp[p][0] - ep[c][0]) * d;
            // Move v from the "on" tallies back to the "off" tallies at this level.
            ep[c][0]++;
            ep[c][1] += d;
            ep[c][2]--;
            ep[c][3] -= d;
            dp[p][0]++;
            dp[p][1] += d;
            dp[p][2]--;
            dp[p][3] -= d;
            // Add pairs (v off, w on) with w in p's component but outside c's.
            sum01 += (dp[p][3] - ep[c][3]);
            sum01 += (dp[p][2] - ep[c][2]) * d;
            c = p;
        }
    };
    ull ans = 0;
    // Pre-order of tree B rooted at 0:
    //   sub[v]  : size of v's subtree in B
    //   down[v] : pre-order index of v
    //   tour    : pre-order sequence, so tour[down[v] .. down[v]+sub[v]) is v's subtree
    vector<int> sub(n);
    vector<int> down(n), tour(n);
    int t = 0;
    auto subsz = [&](auto _subsz, int v, int f) -> void {
        sub[v] = 1;
        tour[t] = v;
        down[v] = t++;
        for (int u : b[v]){
            if (u == f) continue;
            _subsz(_subsz,u,v);
            sub[v] += sub[u];
        }
    };
    subsz(subsz,0,-1);
    // DSU on tree over B. Invariant: nothing is on when dfs(v, ...) is entered.
    //   1. Light children are processed first (top = true, so they clean up).
    //   2. The heavy child (largest subtree) is processed last with top = false,
    //      leaving its subtree on.
    //   3. Turning on the light subtrees and v makes the on-set exactly v's B-subtree.
    //      sum01 is then the dist_A total across the B-edge (v, f); add it to ans.
    //   4. If top, v's whole subtree is turned off again before returning.
    auto dfs = [&](auto _dfs, int v, int f, bool top) -> void {
        int heavy = -1, ch = -1;
        for (int u : b[v]){
            if (u == f) continue;
            if (heavy < sub[u]){
                heavy = sub[u];
                ch = u;
            }
        }
        for (int u : b[v]){
            if (u == f) continue;
            if (u == ch) continue;
            _dfs(_dfs,u,v,true);
        }
        if (ch != -1){
            _dfs(_dfs,ch,v,false);
        }
        // The root has no parent edge, so it contributes nothing (and skips cleanup).
        if (v == 0) return ;
        for (int u : b[v]){
            if (u == f) continue;
            if (u == ch) continue;
            for (int i = down[u]; i < down[u]+sub[u]; i++){
                turn_on(tour[i]);
            }
        }
        turn_on(v);
        ans += sum01;
        if (top){
            for (int i = down[v]; i < down[v]+sub[v]; i++){
                turn_off(tour[i]);
            }
        }
    };
    dfs(dfs,0,-1,true);
    // A pair {x, y} is split by exactly dist_B(x, y) B-edges. So ans now equals
    // sum over unordered pairs of dist_A * dist_B; doubling counts ordered pairs.
    // NOTE(review): the value wraps mod 2^64; confirm the expected output range.
    ans *= 2;
    cout << ans << endl;
}
0