結果
| 問題 |
No.3194 Do Optimize Your Solution
|
| コンテスト | |
| ユーザー |
noya2
|
| 提出日時 | 2025-06-24 19:49:14 |
| 言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 4,762 bytes |
| コンパイル時間 | 3,574 ms |
| コンパイル使用メモリ | 301,004 KB |
| 実行使用メモリ | 125,792 KB |
| 最終ジャッジ日時 | 2025-06-27 20:52:44 |
| 合計ジャッジ時間 | 10,887 ms |
|
ジャッジサーバーID (参考情報) |
judge1 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 4 WA * 3 TLE * 5 -- * 5 |
ソースコード
// GCC convenience header that pulls in the whole standard library.
#include <bits/stdc++.h>
// sz(v): size of a container converted to a signed int.
#define sz(v) ((int)(v).size())
// all(v): expands to the begin/end iterator pair of a container.
#define all(v) (v).begin(), (v).end()
// Textually renames every later use of the identifier `ws`; presumably this keeps the
// global array of that name from clashing with std::ws under `using namespace std`.
#define ws fuckinfoasdjfosadjljo
using namespace std;
using lint = long long;
using ull = unsigned long long;
using pi = pair<int, int>;
// Capacity of every per-vertex table; vertices are indexed from 1.
const int MAXN = 200005;
// const int mod = 1e9 + 7;
// Number of vertices, read in main().
int n;
// Adjacency lists of the two input trees; each entry is (edge weight, neighbour).
vector<pi> gph[2][MAXN];
// Tables for tree 0, filled by dfs1(): lvl = depth counted in edges, dep = depth counted
// in edge weights, din/dout = entry/exit timestamps drawn from the counter piv.
// lg[i] is filled in main() with floor(log2(i)).
int lvl[MAXN], lg[MAXN * 2], dep[MAXN], din[MAXN], dout[MAXN], piv;
// vis[x] is set once x has been picked as a centroid of tree 1 in solve().
bool vis[MAXN];
// Sparse table of (level, vertex) pairs indexed by timestamp; row 0 is written by
// solve() and dfs1(), higher rows are range minima built in solve().
pi spt[19][MAXN * 2];
// True when the [din, dout] interval of y is nested inside the interval of x
// (this includes the case x == y).
bool in(int x, int y){
    if(din[y] < din[x]) return false;
    return dout[y] <= dout[x];
}
void dfs1(int x, int p){
din[x] = piv++;
if(p > 0) spt[0][piv - 1] = pi(lvl[p], p);
for(auto &fuck : gph[0][x]){
int u = fuck.first;
int v = fuck.second;
if(v != p){
dep[v] = dep[x] + u;
lvl[v] = lvl[x] + 1;
dfs1(v, x);
}
}
dout[x] = piv++;
if(p > 0) spt[0][piv - 1] = pi(lvl[p], p);
}
int lca(int x, int y){
if(din[x] > din[y]) swap(x, y);
if(in(x, y)) return x;
int s = dout[x], e = din[y];
int l = lg[e - s + 1];
return min(spt[l][s], spt[l][e - (1<<l) + 1]).second;
}
int col[MAXN], w[MAXN], pae[MAXN];
int cs[MAXN], ws[MAXN];
vector<pi> cmp[MAXN];
vector<int> ord;
void dfs4(int x, int p){
ord.push_back(x);
cs[x] = col[x];
ws[x] = col[x] ?w[x] : 0;
for(auto &fuck : cmp[x]){
int u = fuck.first;
int v = fuck.second;
if(v != p){
pae[v] = u;
dfs4(v, x);
cs[x] += cs[v];
ws[x] += ws[v];
}
}
}
ull tree_comp(vector<int> v){
for(auto &i : v) col[i] = 1;
sort(all(v), [&](const int &x, const int &y){ return din[x] < din[y]; });
for(int i=sz(v)-1; i>0; i--) v.push_back(lca(v[i-1], v[i]));
sort(all(v), [&](const int &x, const int &y){ return din[x] < din[y]; });
v.resize(unique(all(v)) - v.begin());
{
vector<int> stk;
for(auto &i : v){
while(sz(stk) && !in(stk.back(), i)) stk.pop_back();
if(sz(stk)){
int p = stk.back();
int dist = (dep[i] - dep[p]);
cmp[p].emplace_back(dist, i);
cmp[i].emplace_back(dist, p);
}
stk.push_back(i);
}
}
dfs4(v[0], -1);
int sumC = cs[v[0]];
int sumW = ws[v[0]];
ull ret = 0;
for(auto &i : ord){
ret += (1ull * ws[i] * (sumC - cs[i])) * pae[i];
ret += (1ull * cs[i] * (sumW - ws[i])) * pae[i];
}
ord.clear();
for(auto &i : v){
cmp[i].clear();
col[i] = 0;
}
return (2 * ret);
}
namespace cent{
vector<int> dfn;
int siz[MAXN], msz[MAXN];
void dfs2(int x, int p){
dfn.push_back(x);
siz[x] = 1; msz[x] = 0;
for(auto &fuck : gph[1][x]){
int u = fuck.first;
int v = fuck.second;
if(v != p && !vis[v]){
dfs2(v, x);
siz[x] += siz[v];
msz[x] = max(msz[x], siz[v]);
}
}
}
int solve(int x){
dfn.clear();
dfs2(x, -1);
pi dap(1e9, -1);
for(auto &v : dfn){
int ans = max(sz(dfn) - siz[v], msz[v]);
dap = min(dap, pi(ans, v));
}
return dap.second;
}
};
void dfs3(int x, int p, vector<int> &to_comp){
to_comp.push_back(x);
for(auto &fuck : gph[1][x]){
int u = fuck.first;
int v = fuck.second;
if(v != p && !vis[v]){
w[v] = w[x] + u;
dfs3(v, x, to_comp);
}
}
}
void solve(){
for(int i=0; i<=2*n; i++) spt[0][i] = pi(1e9,1e9);
dfs1(1, -1);
for(int i=1; i<19; i++){
for(int j=0; j<=2*n; j++){
spt[i][j] = spt[i-1][j];
if(j + (1<<(i-1)) <= 2*n){
spt[i][j] = min(spt[i][j], spt[i-1][j + (1<<(i-1))]);
}
}
}
queue<int> que;
que.push(1);
ull ans = 0;
while(sz(que)){
int x = que.front(); que.pop();
x = cent::solve(x);
vis[x] = 1;
vector<int> to_comp = {x};
for(auto &fuck : gph[1][x]){
int u = fuck.first;
int v = fuck.second;
if(!vis[v]){
w[v] = u;
vector<int> tmp;
dfs3(v, x, tmp);
que.push(v);
ans += - tree_comp(tmp);
for(auto &i : tmp) to_comp.push_back(i);
}
}
ans += tree_comp(to_comp);
for(auto &i : to_comp){
w[i] = 0;
}
to_comp.clear();
}
printf("%lld\n", ans);
}
static char buf[1 << 19]; // size : any number geq than 1024
static int idx = 0;   // position of the next unread byte in buf
static int bytes = 0; // number of valid bytes currently in buf

// Returns the next byte of stdin as a non-negative value, refilling buf when it is
// exhausted, or EOF once the stream has no more data (stale buffer contents are never
// returned).
static inline int _read() {
    if (idx == bytes) {
        bytes = static_cast<int>(fread(buf, sizeof(buf[0]), sizeof(buf), stdin));
        idx = 0;
        if (bytes <= 0) {
            bytes = 0;
            return EOF;
        }
    }
    return static_cast<unsigned char>(buf[idx++]);
}

// Skips bytes with code <= 32, then reads an optionally '-'-prefixed run of bytes with
// code > 32 as a decimal integer. The terminating byte is consumed.
// Stops cleanly at end of input: returns the value read so far, or 0 if nothing was left.
static inline int _readInt() {
    int c = _read();
    while (c != EOF && c <= 32) c = _read();

    bool negative = false;
    if (c == '-') {
        negative = true;
        c = _read();
    }

    int x = 0;
    // EOF is negative, so it also ends this loop.
    while (c > 32) {
        x = 10 * x + (c - '0');
        c = _read();
    }
    return negative ? -x : x;
}
// Entry point: fills the lg table, reads n followed by the n-1 edges of each of the two
// trees (every edge gets weight 1), runs solve(), then clears the per-test state.
int main(){
    // lg[i] = floor(log2(i)) for i >= 1, built incrementally from lg[i-1].
    for(int i=1; i<2*MAXN; i++){
        int k = lg[i-1];
        while((2 << k) <= i) k++;
        lg[i] = k;
    }

    // A single test case; the count is not read from the input.
    for(int tc = 1; tc > 0; tc--){
        n = _readInt();
        for(auto &tree : gph){
            for(int j=1; j<n; j++){
                const int s = _readInt();
                const int e = _readInt();
                const int x = 1; // the edge weight is fixed rather than read
                tree[s].emplace_back(x, e);
                tree[e].emplace_back(x, s);
            }
        }

        solve();

        for(int i=0; i<=n; i++){
            vis[i] = 0;
            for(auto &tree : gph) tree[i].clear();
        }
        piv = 0;
    }
}
noya2