結果
問題 | No.3194 Do Optimize Your Solution |
ユーザー |
![]() |
提出日時 | 2025-06-25 04:11:45 |
言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
結果 | AC |
実行時間 | 1,122 ms / 3,000 ms |
コード長 | 9,286 bytes |
コンパイル時間 | 3,558 ms |
コンパイル使用メモリ | 294,544 KB |
実行使用メモリ | 74,576 KB |
最終ジャッジ日時 | 2025-06-27 20:53:52 |
合計ジャッジ時間 | 14,649 ms |
ジャッジサーバーID (参考情報) | judge5 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 17 |
ソースコード
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>

// Solution for yukicoder No.3194.
//
// Given two trees A and B on the vertex set {0, ..., n-1}, prints
//     2 * sum over unordered pairs {u, w} of distA(u, w) * distB(u, w).
// Every such pair is separated by exactly distB(u, w) edges of B, so this
// equals 2 * sum over non-root v of B of
//     cut(v) = sum_{u in subtreeB(v), w outside subtreeB(v)} distA(u, w).
// cut(v) is obtained with a small-to-large ("sack") pass over B that turns
// vertices ON/OFF in a centroid decomposition of A.
// All accumulation is unsigned 64-bit and wraps modulo 2^64.
//
// NOTE(review): every tree traversal is recursive with depth up to n; this
// relies on the judge providing a large enough stack.

using ull = unsigned long long;

// Reads one non-negative decimal integer from stdin.
// Leading non-digit bytes (spaces, '\r', '\n', ...) are skipped, and reading
// stops at the first non-digit or at EOF. The byte is kept in an `int` so that
// EOF (-1) or bytes >= 0x80 can never be mistaken for digits.
int input() {
    int c = getchar_unlocked();
    while (c != EOF && (c < '0' || c > '9')) {
        c = getchar_unlocked();
    }
    int x = 0;
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar_unlocked();
    }
    return x;
}

// Compressed-sparse-row adjacency list.
// Usage: add(idx, elem) any number of times, call build() once, then
// operator[](idx) yields the elements added under idx (in insertion order).
template <class E>
struct csr {
    // Read-only [first, last) view over one row; usable in range-for.
    struct row {
        const E* first;
        const E* last;
        const E* begin() const { return first; }
        const E* end() const { return last; }
    };

    csr() = default;
    explicit csr(int n_) : n(n_) {}
    // m is only a capacity hint for the number of add() calls.
    csr(int n_, int m) : n(n_) {
        if (m > 0) {
            start.reserve(m);
            elist.reserve(m);
        }
    }

    // Before build(), start[i] holds the row index of the i-th added element.
    void add(int idx, E elem) {
        start.emplace_back(idx);
        elist.emplace_back(elem);
    }

    // Counting sort by row index. Afterwards row idx occupies
    // elist[start[idx] .. start[idx + 1]).
    void build() {
        const int m = static_cast<int>(start.size());
        std::vector<E> nelist(m);
        std::vector<int> nstart(n + 2, 0);
        for (int i = 0; i < m; i++) nstart[start[i] + 2]++;
        for (int i = 1; i < n; i++) nstart[i + 2] += nstart[i + 1];
        // nstart[idx + 1] is used as the write cursor of row idx.
        for (int i = 0; i < m; i++) nelist[nstart[start[i] + 1]++] = elist[i];
        elist.swap(nelist);
        start.swap(nstart);
    }

    row operator[](int idx) const {
        return {elist.data() + start[idx], elist.data() + start[idx + 1]};
    }

    std::size_t size() const { return n; }

    int n = 0;
    std::vector<int> start;
    std::vector<E> elist;
};

// Reads n-1 edges "u v" (1-indexed) and returns the undirected adjacency list.
csr<int> read_tree(int n) {
    csr<int> g(n, 2 * (n - 1));
    for (int i = 0; i < n - 1; i++) {
        const int u = input() - 1;
        const int v = input() - 1;
        g.add(u, v);
        g.add(v, u);
    }
    g.build();
    return g;
}

// Maintains sum01 = sum over pairs (u ON, w OFF) of distA(u, w) under
// single-vertex flips, using a centroid decomposition of tree A.
// All vertices start OFF.
//
// Each vertex has a weight: +1 while OFF, -1 while ON. For a centroid c with
// component comp(c) (the component it was chosen from):
//   dp[c] = (sum weight(w), sum weight(w) * distA(c, w))        over w in comp(c)
//   ep[c] = (sum weight(w), sum weight(w) * distA(par[c], w))   over w in comp(c)
// Flipping v changes sum01 by +/- sum_{w != v} weight(w) * distA(v, w); that
// sum is assembled from dp/ep along v's centroid-ancestor chain.
struct OnOffDistanceSum {
    struct dat {
        int cnt;  // sum of weights
        ull sum;  // sum of weight * distance (mod 2^64)
    };

    explicit OnOffDistanceSum(const csr<int>& a) {
        const int n = static_cast<int>(a.size());
        // Component sizes at least halve per level, so every centroid depth
        // is <= floor(log2 n); `levels` = floor(log2 n) + 1 slots suffice.
        while ((1LL << levels) <= n) ++levels;
        done.assign(n, 0);
        par.assign(n, -1);
        dep.assign(n, 0);
        dist.assign(static_cast<std::size_t>(n) * levels, 0);
        dp.assign(n, dat{0, 0});
        ep.assign(n, dat{0, 0});
        if (n > 0) decompose(a, 0, n, 0);
    }

    // Flips v: OFF -> ON when `on` is true, ON -> OFF otherwise.
    // The caller must only turn ON vertices that are OFF, and vice versa.
    void toggle(int v, bool on) {
        const int step = on ? -2 : 2;  // change of v's weight
        // delta = sum_{w != v} weight(w) * distA(v, w), using pre-flip weights.
        ull delta = dp[v].sum;  // w inside v's own component (v itself adds 0)
        dp[v].cnt += step;
        for (int c = v; dep[c] > 0; c = par[c]) {
            const int p = par[c];
            const int d = dist_at(v, dep[c]);  // distA(p, v)
            // For w in comp(p) \ comp(c) the path v -> w passes through p,
            // so distA(v, w) = distA(p, w) + d.
            delta += dp[p].sum - ep[c].sum;
            delta += static_cast<ull>(dp[p].cnt - ep[c].cnt) * d;
            const ull shift = static_cast<ull>(step * d);  // wraps for step < 0
            ep[c].cnt += step;
            ep[c].sum += shift;
            dp[p].cnt += step;
            dp[p].sum += shift;
        }
        if (on) {
            sum01 += delta;
        } else {
            sum01 -= delta;
        }
    }

    ull value() const { return sum01; }

    // dist_at(v, k) for k >= 1: distA(v, centroid of depth k-1 whose component
    // contains v). Level-0 entries are written but never read.
    int& dist_at(int v, int k) {
        return dist[static_cast<std::size_t>(v) * levels + k];
    }

    // Decomposes the live component of size sz containing `one`, at depth
    // `level`. `one` is adjacent to the previous centroid (or is vertex 0 at
    // the top level), so distances start at 1 and are measured from that
    // previous centroid.
    // Returns {chosen centroid, sum of those distances over the component}.
    std::pair<int, ull> decompose(const csr<int>& a, int one, int sz, int level) {
        int ctr = -1;
        int ctr_parent = -1;
        ull dsum = 0;
        auto measure = [&](auto self, int v, int f, int d) -> void {
            dp[v].cnt = 1;
            dist_at(v, level) = d;
            dsum += d;
            for (int u : a[v]) {
                if (done[u] || u == f) continue;
                self(self, u, v, d + 1);
                dp[v].cnt += dp[u].cnt;
            }
            // The first vertex in post-order whose subtree holds at least half
            // of the component is a centroid: its children are each below half,
            // and the remainder is at most half.
            if (ctr == -1 && dp[v].cnt * 2 >= sz) {
                ctr = v;
                ctr_parent = f;
            }
        };
        measure(measure, one, -1, 1);

        dep[ctr] = level;
        done[ctr] = 1;
        ull ssum = 0;
        for (int nb : a[ctr]) {
            if (done[nb]) continue;
            // The branch through ctr's DFS parent is the complement of its subtree.
            const int c0 = (nb == ctr_parent) ? sz - dp[ctr].cnt : dp[nb].cnt;
            const auto [child, s0] = decompose(a, nb, c0, level + 1);
            par[child] = ctr;
            ep[child] = {c0, s0};
            ssum += s0;
        }
        dp[ctr] = {sz, ssum};
        return {ctr, dsum};
    }

    int levels = 1;          // distance slots stored per vertex
    std::vector<char> done;  // already chosen as a centroid
    std::vector<int> par;    // centroid-tree parent (-1 for the root centroid)
    std::vector<int> dep;    // centroid depth
    std::vector<int> dist;   // flat table, see dist_at()
    std::vector<dat> dp;
    std::vector<dat> ep;
    ull sum01 = 0;
};

int main() {
    std::cin.tie(nullptr);
    std::ios::sync_with_stdio(false);

    const int n = input();
    const csr<int> a = read_tree(n);
    const csr<int> b = read_tree(n);
    if (n < 2) {  // no pair of distinct vertices
        std::cout << 0 << '\n';
        return 0;
    }

    OnOffDistanceSum on_off(a);

    // Pre-order tour of B rooted at 0: the subtree of v is
    // tour[down[v] .. down[v] + sub[v]).
    std::vector<int> sub(n), down(n), tour(n);
    int t = 0;
    auto euler = [&](auto self, int v, int f) -> void {
        sub[v] = 1;
        tour[t] = v;
        down[v] = t++;
        for (int u : b[v]) {
            if (u == f) continue;
            self(self, u, v);
            sub[v] += sub[u];
        }
    };
    euler(euler, 0, -1);

    // Small-to-large over B. Invariant: on entry every vertex is OFF.
    // Light children are handled first and cleared afterwards. The heavy child
    // keeps its vertices ON, and then the rest of subtreeB(v) is turned ON.
    // At that point sum01 == cut(v).
    ull ans = 0;
    auto sack = [&](auto self, int v, int f, bool clear_after) -> void {
        int heavy = -1;
        int heavy_size = -1;
        for (int u : b[v]) {
            if (u != f && sub[u] > heavy_size) {
                heavy_size = sub[u];
                heavy = u;
            }
        }
        for (int u : b[v]) {
            if (u != f && u != heavy) self(self, u, v, true);
        }
        if (heavy != -1) self(self, heavy, v, false);

        if (f < 0) return;  // the root has no parent edge to count

        for (int u : b[v]) {
            if (u == f || u == heavy) continue;
            for (int i = down[u]; i < down[u] + sub[u]; i++) on_off.toggle(tour[i], true);
        }
        on_off.toggle(v, true);
        ans += on_off.value();

        if (clear_after) {
            for (int i = down[v]; i < down[v] + sub[v]; i++) on_off.toggle(tour[i], false);
        }
    };
    sack(sack, 0, -1, true);

    std::cout << ans * 2 << '\n';
    return 0;
}