結果

問題 No.3194 Do Optimize Your Solution
ユーザー noya2
提出日時 2025-06-24 19:54:04
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 4,770 bytes
コンパイル時間 4,774 ms
コンパイル使用メモリ 300,740 KB
実行使用メモリ 125,340 KB
最終ジャッジ日時 2025-06-27 20:52:36
合計ジャッジ時間 11,000 ms
ジャッジサーバーID
(参考情報)
judge2 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other WA * 1 TLE * 1 -- * 15
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
#define sz(v) ((int)(v).size())
#define all(v) (v).begin(), (v).end()
#define ws fuckinfoasdjfosadjljo
using namespace std;
using lint = long long;
using ull = unsigned long long;
using pi = pair<int, int>;
const int MAXN = 200005;
// const int mod = 1e9 + 7;

// Number of vertices. The same vertex id indexes both trees; ids appear to be
// 1-based (dfs1 and the centroid queue both start at vertex 1).
int n;
// gph[0] / gph[1]: adjacency lists of the two input trees; each entry is
// (edge length, neighbor). main() currently gives every edge length 1.
vector<pi> gph[2][MAXN];
// Tree-0 data filled by dfs1: lvl = depth in edges, dep = weighted depth,
// din / dout = entry / exit timestamps drawn from the shared counter piv.
// lg[i] = floor(log2(i)), precomputed in main for the sparse-table queries.
int lvl[MAXN], lg[MAXN * 2], dep[MAXN], din[MAXN], dout[MAXN], piv;
// vis[v]: v has already been taken as a centroid of tree 1.
bool vis[MAXN];
// 19-level sparse table over the timestamp slots; spt[0][t] holds
// (lvl[parent], parent) of the vertex entered/exited at time t (see dfs1).
pi spt[19][MAXN * 2];
// True iff x is an ancestor of y in tree 0 (x == y counts): y's entry/exit
// interval from dfs1 must nest inside x's.
bool in(int x, int y){
	const bool entered_no_earlier = din[x] <= din[y];
	const bool left_no_later = dout[y] <= dout[x];
	return entered_no_earlier && left_no_later;
}

// Euler-tour DFS over tree 0 from x (p = parent, -1 for the root).
// Gives x an entry and an exit timestamp from the shared counter piv, fills
// lvl (edge depth) and dep (weighted depth) for each child, and stores
// (lvl[p], p) in spt[0] at both of x's timestamp slots so lca() can take a
// range minimum over them.
void dfs1(int x, int p){
	const int enter = piv++;
	din[x] = enter;
	if(p > 0) spt[0][enter] = pi(lvl[p], p);
	for(const auto &[len, to] : gph[0][x]){
		if(to == p) continue;
		lvl[to] = lvl[x] + 1;
		dep[to] = dep[x] + len;
		dfs1(to, x);
	}
	const int leave = piv++;
	dout[x] = leave;
	if(p > 0) spt[0][leave] = pi(lvl[p], p);
}

int lca(int x, int y){
	if(din[x] > din[y]) swap(x, y);
	if(in(x, y)) return x;
	int s = dout[x], e = din[y];
	int l = lg[e - s + 1];
	return min(spt[l][s], spt[l][e - (1<<l) + 1]).second;
}

// Scratch state shared by tree_comp and dfs4.
// col[v] == 1 while v belongs to the vertex set passed to tree_comp.
// pae[v] = length of the edge from v to its parent in the compressed tree (written by dfs4).
int col[MAXN], pae[MAXN];
// w[v]: tree-1 distance from the current centroid to v (written in solve / dfs3,
// reset to 0 after each centroid is processed).
ull w[MAXN];
// cs[v]: number of col-marked vertices in v's compressed-tree subtree (dfs4).
int cs[MAXN];
// ws[v]: sum of w over col-marked vertices in v's compressed-tree subtree (dfs4).
// NOTE: `ws` is #defined to another identifier near the top of the file,
// presumably to avoid clashing with std::ws.
ull ws[MAXN];
// Adjacency of the compressed (virtual) tree built by tree_comp; entries are
// (edge length, neighbor).
vector<pi> cmp[MAXN];
// Compressed-tree vertices in the preorder in which dfs4 visits them.
vector<int> ord;

void dfs4(int x, int p){
	ord.push_back(x);
	cs[x] = col[x];
	ws[x] = col[x] ?w[x] : 0;
	for(auto &fuck : cmp[x]){
		int u = fuck.first;
		int v = fuck.second;
		if(v != p){
			pae[v] = u;
			dfs4(v, x);
			cs[x] += cs[v];
			ws[x] += ws[v];
		}
	}
}

// Builds the virtual (compressed) tree of tree 0 spanned by the vertex set v,
// closed under LCA. Returns, modulo 2^64,
//     2 * sum over unordered pairs {a, b} of distinct vertices of v of
//         dist0(a, b) * (w[a] + w[b]),
// where dist0 is the dep-weighted distance in tree 0 and w[] holds the weight
// the caller assigned to each vertex.
// A compressed edge of length L that separates a subtree from the rest
// contributes L * (ws_below * cnt_above + cnt_below * ws_above).
// v is expected to hold distinct vertices. col, cmp and ord are restored
// before returning.
ull tree_comp(vector<int> v){
	for(auto &i : v) col[i] = 1;
	const auto by_din = [&](const int &x, const int &y){ return din[x] < din[y]; };
	sort(all(v), by_din);
	// Adding the LCA of every din-adjacent pair closes the set under all pairwise LCAs.
	for(int i=sz(v)-1; i>0; i--) v.push_back(lca(v[i-1], v[i]));
	sort(all(v), by_din);
	v.resize(unique(all(v)) - v.begin());
	{
		// Stack holds the current root-to-vertex chain; each vertex is linked
		// to its nearest ancestor already present in the set.
		vector<int> stk;
		for(auto &i : v){
			while(sz(stk) && !in(stk.back(), i)) stk.pop_back();
			if(sz(stk)){
				int p = stk.back();
				int dist = (dep[i] - dep[p]);
				cmp[p].emplace_back(dist, i);
				cmp[i].emplace_back(dist, p);
			}
			stk.push_back(i);
		}
	}
	// v[0] (smallest din) is the LCA of the whole set and therefore the root.
	// It has no parent edge, so clear any stale pae left by an earlier call.
	pae[v[0]] = 0;
	dfs4(v[0], -1);
	const int sumC = cs[v[0]];
	// Must be ull. The total weight can exceed INT_MAX for large components, and
	// the former `int` truncation corrupted every (sumW - ws[i]) term below.
	const ull sumW = ws[v[0]];
	ull ret = 0;
	for(auto &i : ord){
		ret += (1ull * ws[i] * (sumC - cs[i])) * pae[i];
		ret += (1ull * cs[i] * (sumW - ws[i])) * pae[i];
	}
	ord.clear();
	for(auto &i : v){
		cmp[i].clear();
		col[i] = 0;
	}
	return (2 * ret);
}

namespace cent{
	vector<int> dfn;
	int siz[MAXN], msz[MAXN];
	void dfs2(int x, int p){
		dfn.push_back(x);
		siz[x] = 1; msz[x] = 0;
	for(auto &fuck : gph[1][x]){
		int u = fuck.first;
		int v = fuck.second;
			if(v != p && !vis[v]){
				dfs2(v, x);
				siz[x] += siz[v];
				msz[x] = max(msz[x], siz[v]);
			}
		}
	}

	int solve(int x){
		dfn.clear();
		dfs2(x, -1);
		pi dap(1e9, -1);
		for(auto &v : dfn){
			int ans = max(sz(dfn) - siz[v], msz[v]);
			dap = min(dap, pi(ans, v));
		}
		return dap.second;
	}
};

void dfs3(int x, int p,  vector<int> &to_comp){
	to_comp.push_back(x);
	for(auto &fuck : gph[1][x]){
		int u = fuck.first;
		int v = fuck.second;
		if(v != p && !vis[v]){
			w[v] = w[x] + u;
			dfs3(v, x, to_comp);
		}
	}
}

void solve(){
	for(int i=0; i<=2*n; i++) spt[0][i] = pi(1e9,1e9);
	dfs1(1, -1);
	for(int i=1; i<19; i++){
		for(int j=0; j<=2*n; j++){
			spt[i][j] = spt[i-1][j];
			if(j + (1<<(i-1)) <= 2*n){
				spt[i][j] = min(spt[i][j], spt[i-1][j + (1<<(i-1))]);
			}
		}
	}
	queue<int> que;
	que.push(1);
	ull ans = 0;
	while(sz(que)){
		int x = que.front(); que.pop();
		x = cent::solve(x);
		vis[x] = 1;
		vector<int> to_comp = {x};
	for(auto &fuck : gph[1][x]){
		int u = fuck.first;
		int v = fuck.second;
			if(!vis[v]){
				w[v] = u;
				vector<int> tmp;
				dfs3(v, x, tmp);
				que.push(v);
				ans += - tree_comp(tmp);
				for(auto &i : tmp) to_comp.push_back(i);
			}
		}
		ans += tree_comp(to_comp);
		for(auto &i : to_comp){
			w[i] = 0;
		}
		to_comp.clear();
	}
	printf("%lld\n", ans);
}

// Minimal buffered reader for stdin.
static char buf[1 << 19]; // size : any number geq than 1024
static int idx = 0;
static int bytes = 0;
// Returns the next input byte (as a signed char value), or EOF once stdin is
// exhausted. Previously it kept returning the stale buf[0] at end of input.
// _readInt() could then spin forever, e.g. when the input does not end with
// whitespace and buf[0] is a digit.
// A raw 0xFF input byte reads as EOF; the expected input is plain ASCII.
static inline int _read() {
	if (idx == bytes) {
		bytes = (int)fread(buf, sizeof(buf[0]), sizeof(buf), stdin);
		idx = 0;
		if (bytes <= 0) {
			bytes = 0;
			return EOF;
		}
	}
	return buf[idx++];
}
// Reads the next (optionally negative) decimal integer, skipping leading bytes
// <= ' '. Returns 0 if the input ends before any digit is seen.
static inline int _readInt() {
	int x = 0, s = 1;
	int c = _read();
	while (c != EOF && c <= 32) c = _read();
	if (c == '-') s = -1, c = _read();
	while (c > 32) x = 10 * x + (c - '0'), c = _read();
	if (s < 0) x = -x;
	return x;
}


// Entry point: builds the floor-log2 table, reads n and the n-1 edges of each
// of the two trees, runs solve(), then resets per-test state.
int main(){
	// lg[i] = floor(log2(i)) for i >= 2; lg[0] and lg[1] keep their zero initial value.
	for(int i = 2; i < 2 * MAXN; ++i) lg[i] = lg[i >> 1] + 1;
	// Exactly one test case; the multi-test count read is not used.
	for(int tc = 1; tc > 0; --tc){
		n = _readInt();
		// Tree 0's edges come first, then tree 1's. Edge lengths are not read; all are 1.
		for(int t = 0; t < 2; ++t){
			for(int k = 1; k < n; ++k){
				const int a = _readInt();
				const int b = _readInt();
				const int len = 1;
				gph[t][a].emplace_back(len, b);
				gph[t][b].emplace_back(len, a);
			}
		}
		solve();
		// Per-test cleanup (only matters if more test cases are ever enabled).
		for(int i = 0; i <= n; ++i){
			vis[i] = 0;
			gph[0][i].clear();
			gph[1][i].clear();
		}
		piv = 0;
	}
}
0