結果
問題 | No.3194 Do Optimize Your Solution |
ユーザー |
![]() |
提出日時 | 2025-06-25 00:13:21 |
言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
結果 | AC |
実行時間 | 1,841 ms / 3,000 ms |
コード長 | 5,353 bytes |
コンパイル時間 | 3,458 ms |
コンパイル使用メモリ | 294,164 KB |
実行使用メモリ | 82,316 KB |
最終ジャッジ日時 | 2025-06-27 20:53:24 |
合計ジャッジ時間 | 21,853 ms |
ジャッジサーバーID (参考情報) | judge1 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 17 |
ソースコード
// yukicoder No.3194 "Do Optimize Your Solution".
//
// Reads n, then the n-1 edges of tree A, then the n-1 edges of tree B
// (1-indexed vertices) and prints
//     sum over ordered pairs (i, j) of dist_A(i, j) * dist_B(i, j).
//
// dist_B(i, j) equals the number of B-edges separating i and j, so the answer
// is twice the sum, over every B-edge (v, parent(v)), of the total A-distance
// between the B-subtree of v and all remaining vertices.  Each such cut value
// is obtained by switching the B-subtree "on" (DSU on tree over B) while a
// centroid decomposition of A keeps
//     sum01 = sum over (u on, w off) of dist_A(u, w)
// up to date in O(log n) per vertex toggle.
//
// NOTE(review): all sums are unsigned 64-bit, i.e. computed modulo 2^64;
// confirm the true answer fits within the problem's bounds.
// NOTE(review): the tree walks below are recursive with depth up to n on
// path-shaped trees, so this relies on a sufficiently large stack.
#include <array>
#include <cassert>
#include <cstdio>
#include <vector>

namespace {

using ull = unsigned long long;

// Number of centroid levels stored per vertex in `dist`.  Every child
// component of a centroid has at most half of its parent component's
// vertices, so level indices grow like log2(n); an assert in the
// decomposition guards this bound.
constexpr int kMaxLevels = 19;

// Reads the next non-negative decimal integer from stdin.  Leading non-digit
// bytes (spaces, '\n', '\r', ...) are skipped and the byte terminating the
// number is consumed.  Returns 0 if EOF is reached before any digit.
int input() {
    int c = getchar_unlocked();  // int, not char, so EOF stays distinguishable
    while (c != EOF && (c < '0' || c > '9')) c = getchar_unlocked();
    int x = 0;
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar_unlocked();
    }
    return x;
}

}  // namespace

int main() {
    const int n = input();
    if (n <= 1) {  // no pairs at all; also avoids indexing empty adjacency lists
        std::printf("0\n");
        return 0;
    }

    std::vector<std::vector<int>> a(n), b(n);
    for (int i = 0; i < n - 1; ++i) {  // tree A: distances are measured here
        const int u = input() - 1;
        const int v = input() - 1;
        a[u].emplace_back(v);
        a[v].emplace_back(u);
    }
    for (int i = 0; i < n - 1; ++i) {  // tree B: supplies the on/off cuts
        const int u = input() - 1;
        const int v = input() - 1;
        b[u].emplace_back(v);
        b[v].emplace_back(u);
    }

    // ---- Centroid decomposition of A --------------------------------------
    // For a centroid x, dp[x] = {off count, off distance sum, on count,
    // on distance sum} over x's whole component, distances measured to x.
    // For a non-root centroid c, ep[c] holds the same four values restricted
    // to c's component but measured to par[c], so dp[par[c]] - ep[c] describes
    // "par[c]'s component minus the branch that contains c".
    std::vector<char> done(n, 0);  // vertex already chosen as a centroid
    std::vector<int> par(n);       // parent in the centroid tree (-1 at root)
    std::vector<int> dep(n);       // centroid level of each vertex
    std::vector<std::array<int, kMaxLevels>> dist(n);  // dist[v][k]: to level-k centroid
    std::vector<std::array<ull, 4>> dp(n), ep(n);

    // Decomposes the component containing `start` (which has `sz` vertices)
    // at level `lvl` and returns the centroid chosen for it.
    auto decompose = [&](auto&& self, int start, int sz, int lvl) -> int {
        int ctr = -1;
        // Post-order subtree sizes; the first vertex whose subtree holds at
        // least half of the component is a centroid (its child branches are
        // all smaller than half, and the rest is at most half).
        auto measure = [&](auto&& rec, int v, int from) -> int {
            int size = 1;
            for (int u : a[v]) {
                if (done[u] || u == from) continue;
                size += rec(rec, u, v);
            }
            if (size * 2 >= sz && ctr == -1) ctr = v;
            return size;
        };
        measure(measure, start, -1);
        assert(ctr != -1);
        dep[ctr] = lvl;
        done[ctr] = 1;

        ull total_cnt = 1;  // the centroid itself, at distance 0
        ull total_sum = 0;
        for (int head : a[ctr]) {
            if (done[head]) continue;
            assert(lvl < kMaxLevels);  // dist[][lvl] must stay in bounds
            ull cnt = 0;
            ull sum = 0;
            // Records every vertex's distance to ctr within the branch at `head`.
            auto walk = [&](auto&& rec, int v, int from, int d) -> void {
                ++cnt;
                sum += d;
                dist[v][lvl] = d;
                for (int u : a[v]) {
                    if (done[u] || u == from) continue;
                    rec(rec, u, v, d + 1);
                }
            };
            walk(walk, head, ctr, 1);
            const int child = self(self, head, static_cast<int>(cnt), lvl + 1);
            par[child] = ctr;
            ep[child] = {cnt, sum, 0, 0};
            total_cnt += cnt;
            total_sum += sum;
        }
        dp[ctr] = {total_cnt, total_sum, 0, 0};  // every vertex starts "off"
        return ctr;
    };
    const int root = decompose(decompose, 0, n, 0);
    par[root] = -1;

    // ---- Maintaining sum01 = sum over (u on, w off) of dist_A(u, w) --------
    // Each pair is charged at the deepest centroid p whose component holds
    // both endpoints; there dist_A(u, w) = dist(u, p) + dist(w, p).
    ull sum01 = 0;
    std::vector<char> on(n, 0);

    // Moves v from state group `src` to group `dst` (group 0 = off, 2 = on;
    // the count lives at index g and the distance sum at g + 1), removing the
    // pairs v formed in its old state and adding those it forms in the new one.
    auto relocate = [&](int v, int src, int dst) {
        // v's own centroid level: v is at distance 0 from that centroid.
        sum01 -= dp[v][dst + 1];
        --dp[v][src];
        ++dp[v][dst];
        sum01 += dp[v][src + 1];
        int c = v;
        for (int lvl = dep[v] - 1; lvl >= 0; --lvl) {
            const int p = par[c];
            const ull d = static_cast<ull>(dist[v][lvl]);
            std::array<ull, 4>& whole = dp[p];   // p's entire component
            std::array<ull, 4>& branch = ep[c];  // the part on v's side of p
            // Pairs through p that v formed while in state `src`.
            sum01 -= (whole[dst + 1] - branch[dst + 1]) + (whole[dst] - branch[dst]) * d;
            --branch[src];
            branch[src + 1] -= d;
            ++branch[dst];
            branch[dst + 1] += d;
            --whole[src];
            whole[src + 1] -= d;
            ++whole[dst];
            whole[dst + 1] += d;
            // Pairs through p that v forms now that it is in state `dst`.
            sum01 += (whole[src + 1] - branch[src + 1]) + (whole[src] - branch[src]) * d;
            c = p;
        }
    };
    auto turn_on = [&](int v) {
        assert(!on[v]);
        on[v] = 1;
        relocate(v, 0, 2);
    };
    auto turn_off = [&](int v) {
        assert(on[v]);
        on[v] = 0;
        relocate(v, 2, 0);
    };

    // ---- DSU on tree over B ------------------------------------------------
    std::vector<int> sub(n), tin(n), order(n);
    int timer = 0;
    // Subtree sizes and preorder of B rooted at 0, so that
    // order[tin[v] .. tin[v] + sub[v]) lists exactly the B-subtree of v.
    auto euler = [&](auto&& rec, int v, int from) -> void {
        sub[v] = 1;
        order[timer] = v;
        tin[v] = timer++;
        for (int u : b[v]) {
            if (u == from) continue;
            rec(rec, u, v);
            sub[v] += sub[u];
        }
    };
    euler(euler, 0, -1);

    ull ans = 0;
    // Entered with nothing switched on.  Light children are solved and wiped
    // first, the heavy child is solved and kept, then the light subtrees and
    // v itself are switched on so that exactly subtree(v) is on.  If `clear`
    // is set, the subtree is switched off again before returning.
    auto solve = [&](auto&& rec, int v, int from, bool clear) -> void {
        int heavy = -1;
        int heavy_size = -1;
        for (int u : b[v]) {
            if (u != from && heavy_size < sub[u]) {
                heavy_size = sub[u];
                heavy = u;
            }
        }
        for (int u : b[v]) {
            if (u != from && u != heavy) rec(rec, u, v, true);
        }
        if (heavy != -1) rec(rec, heavy, v, false);
        if (v == 0) return;  // the root has no parent edge to charge
        for (int u : b[v]) {
            if (u == from || u == heavy) continue;
            for (int i = tin[u]; i < tin[u] + sub[u]; ++i) turn_on(order[i]);
        }
        turn_on(v);
        ans += sum01;  // A-distance across the cut of B-edge (v, from)
        if (clear) {
            for (int i = tin[v]; i < tin[v] + sub[v]; ++i) turn_off(order[i]);
        }
    };
    solve(solve, 0, -1, true);

    ans *= 2;  // every unordered pair counted once so far -> ordered pairs
    std::printf("%llu\n", ans);
    return 0;
}