結果
問題 |
No.3194 Do Optimize Your Solution
|
ユーザー |
![]() |
提出日時 | 2025-06-24 19:43:51 |
言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
結果 |
RE
|
実行時間 | - |
コード長 | 4,762 bytes |
コンパイル時間 | 3,783 ms |
コンパイル使用メモリ | 301,008 KB |
実行使用メモリ | 122,668 KB |
最終ジャッジ日時 | 2025-06-27 20:52:24 |
合計ジャッジ時間 | 12,328 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 1 WA * 1 RE * 2 TLE * 8 -- * 5 |
ソースコード
#include <bits/stdc++.h> #define sz(v) ((int)(v).size()) #define all(v) (v).begin(), (v).end() #define ws fuckinfoasdjfosadjljo using namespace std; using lint = long long; using ull = unsigned long long; using pi = pair<int, int>; const int MAXN = 200005; // const int mod = 1e9 + 7; int n; vector<pi> gph[2][MAXN]; int lvl[MAXN], lg[MAXN * 2], dep[MAXN], din[MAXN], dout[MAXN], piv; bool vis[MAXN]; pi spt[18][MAXN * 2]; bool in(int x, int y){ return din[x] <= din[y] && dout[y] <= dout[x]; } void dfs1(int x, int p){ din[x] = piv++; if(p > 0) spt[0][piv - 1] = pi(lvl[p], p); for(auto &fuck : gph[0][x]){ int u = fuck.first; int v = fuck.second; if(v != p){ dep[v] = dep[x] + u; lvl[v] = lvl[x] + 1; dfs1(v, x); } } dout[x] = piv++; if(p > 0) spt[0][piv - 1] = pi(lvl[p], p); } int lca(int x, int y){ if(din[x] > din[y]) swap(x, y); if(in(x, y)) return x; int s = dout[x], e = din[y]; int l = lg[e - s + 1]; return min(spt[l][s], spt[l][e - (1<<l) + 1]).second; } int col[MAXN], w[MAXN], pae[MAXN]; int cs[MAXN], ws[MAXN]; vector<pi> cmp[MAXN]; vector<int> ord; void dfs4(int x, int p){ ord.push_back(x); cs[x] = col[x]; ws[x] = col[x] ?w[x] : 0; for(auto &fuck : cmp[x]){ int u = fuck.first; int v = fuck.second; if(v != p){ pae[v] = u; dfs4(v, x); cs[x] += cs[v]; ws[x] += ws[v]; } } } lint tree_comp(vector<int> v){ for(auto &i : v) col[i] = 1; sort(all(v), [&](const int &x, const int &y){ return din[x] < din[y]; }); for(int i=sz(v)-1; i>0; i--) v.push_back(lca(v[i-1], v[i])); sort(all(v), [&](const int &x, const int &y){ return din[x] < din[y]; }); v.resize(unique(all(v)) - v.begin()); { vector<int> stk; for(auto &i : v){ while(sz(stk) && !in(stk.back(), i)) stk.pop_back(); if(sz(stk)){ int p = stk.back(); int dist = (dep[i] - dep[p]); cmp[p].emplace_back(dist, i); cmp[i].emplace_back(dist, p); } stk.push_back(i); } } dfs4(v[0], -1); int sumC = cs[v[0]]; int sumW = ws[v[0]]; lint ret = 0; for(auto &i : ord){ ret += (1ll * ws[i] * (sumC - cs[i])) * pae[i]; ret += (1ll * cs[i] * (sumW - ws[i])) * pae[i]; } ord.clear(); for(auto &i : v){ cmp[i].clear(); col[i] = 0; } return (2 * ret); } namespace cent{ vector<int> dfn; int siz[MAXN], msz[MAXN]; void dfs2(int x, int p){ dfn.push_back(x); siz[x] = 1; msz[x] = 0; for(auto &fuck : gph[1][x]){ int u = fuck.first; int v = fuck.second; if(v != p && !vis[v]){ dfs2(v, x); siz[x] += siz[v]; msz[x] = max(msz[x], siz[v]); } } } int solve(int x){ dfn.clear(); dfs2(x, -1); pi dap(1e9, -1); for(auto &v : dfn){ int ans = max(sz(dfn) - siz[v], msz[v]); dap = min(dap, pi(ans, v)); } return dap.second; } }; void dfs3(int x, int p, vector<int> &to_comp){ to_comp.push_back(x); for(auto &fuck : gph[1][x]){ int u = fuck.first; int v = fuck.second; if(v != p && !vis[v]){ w[v] = w[x] + u; dfs3(v, x, to_comp); } } } void solve(){ for(int i=0; i<=2*n; i++) spt[0][i] = pi(1e9,1e9); dfs1(1, -1); for(int i=1; i<18; i++){ for(int j=0; j<=2*n; j++){ spt[i][j] = spt[i-1][j]; if(j + (1<<(i-1)) <= 2*n){ spt[i][j] = min(spt[i][j], spt[i-1][j + (1<<(i-1))]); } } } queue<int> que; que.push(1); ull ans = 0; while(sz(que)){ int x = que.front(); que.pop(); x = cent::solve(x); vis[x] = 1; vector<int> to_comp = {x}; for(auto &fuck : gph[1][x]){ int u = fuck.first; int v = fuck.second; if(!vis[v]){ w[v] = u; vector<int> tmp; dfs3(v, x, tmp); que.push(v); ans += - tree_comp(tmp); for(auto &i : tmp) to_comp.push_back(i); } } ans += tree_comp(to_comp); for(auto &i : to_comp){ w[i] = 0; } to_comp.clear(); } printf("%lld\n", ans); } static char buf[1 << 19]; // size : any number geq than 1024 static int idx = 0; static int bytes = 0; static inline int _read() { if (!bytes || idx == bytes) { bytes = (int)fread(buf, sizeof(buf[0]), sizeof(buf), stdin); idx = 0; } return buf[idx++]; } static inline int _readInt() { int x = 0, s = 1; int c = _read(); while (c <= 32) c = _read(); if (c == '-') s = -1, c = _read(); while (c > 32) x = 10 * x + (c - '0'), c = _read(); if (s < 0) x = -x; return x; } int main(){ for(int i=1; i<2*MAXN; i++){ lg[i] = lg[i-1]; while((2 << lg[i]) <= i) lg[i]++; } // int tc = _readInt(); int tc = 1; while(tc--){ n = _readInt(); for(int i=0; i<2; i++){ for(int j=1; j<n; j++){ int s = _readInt(); int e = _readInt(); // int x = _readInt(); int x = 1; gph[i][s].emplace_back(x, e); gph[i][e].emplace_back(x, s); } } solve(); for(int i=0; i<=n; i++){ vis[i] = 0; gph[0][i].clear(); gph[1][i].clear(); } piv = 0; } }