結果
問題 |
No.1002 Twotone
|
ユーザー |
![]() |
提出日時 | 2020-02-29 13:03:13 |
言語 | C++14 (gcc 13.3.0 + boost 1.87.0) |
結果 |
AC
|
実行時間 | 2,133 ms / 5,000 ms |
コード長 | 3,273 bytes |
コンパイル時間 | 1,460 ms |
コンパイル使用メモリ | 123,616 KB |
実行使用メモリ | 57,984 KB |
最終ジャッジ日時 | 2024-10-13 19:56:48 |
合計ジャッジ時間 | 16,884 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 33 |
ソースコード
#include <iostream> #include <algorithm> #include <string> #include <vector> #include <cmath> #include <map> #include <queue> #include <iomanip> #include <set> #include <tuple> #define mkp make_pair #define mkt make_tuple #define rep(i,n) for(int i = 0; i < (n); ++i) using namespace std; typedef long long ll; const ll MOD=1e9+7; template<class T> void chmin(T &a,const T &b){if(a>b) a=b;} template<class T> void chmax(T &a,const T &b){if(a<b) a=b;} struct Centroid{ vector<vector<int>> g; vector<int> sz,dead; Centroid(){} Centroid(int V):sz(V,1),dead(V,0),g(V){} void initialize(int V){ sz.resize(V,1); dead.resize(V,0); g.resize(V); } void add_edge(int u,int v){ g[u].push_back(v); g[v].push_back(u); } int szdfs(int now,int par){ sz[now]=1; for(auto nex:g[now]){ if(nex==par||dead[nex]) continue; sz[now]+=szdfs(nex,now); } return sz[now]; } void findCentroid(int now,int par,int V,vector<int> &cens){ bool ok=true; for(auto nex:g[now]){ if(nex==par||dead[nex]) continue; findCentroid(nex,now,V,cens); if(sz[nex]>V/2) ok=false; } if(V-sz[now]>V/2) ok=false; if(ok) cens.push_back(now); } vector<int> build(int root){ vector<int> cens; szdfs(root,-1); findCentroid(root,-1,sz[root],cens); return cens; } void kill(int now){ dead[now]=1; } bool alive(int now){ return dead[now]==0; } }; struct Edge{ int to,col; Edge(int to,int col):to(to),col(col){} }; Centroid cent; int N,K; vector<vector<Edge>> g; map<pair<int,int>,int> twoCol; map<int,int> two,one; int oneCol; map<pair<int,int>,int> tcol; map<int,int> tmp,omp; int ocol; ll calc(int now,int par,int fir,int sec){ ll res=0; if(sec==-1){ res+=two[fir]-tmp[fir]; res+=(oneCol-ocol)-(one[fir]-omp[fir]); one[fir]++;omp[fir]++; oneCol++;ocol++; }else{ res+=twoCol[minmax(fir,sec)]-tcol[minmax(fir,sec)]+1; res+=one[fir]-omp[fir]; res+=one[sec]-omp[sec]; twoCol[minmax(fir,sec)]++;tcol[minmax(fir,sec)]++; two[fir]++;two[sec]++; tmp[fir]++;tmp[sec]++; } for(auto e:g[now]){ if(cent.alive(e.to)==false) continue; if(e.to==par) continue; if(sec==-1){ if(fir==e.col) res+=calc(e.to,now,fir,sec); else res+=calc(e.to,now,fir,e.col); }else{ if(fir==e.col||sec==e.col) res+=calc(e.to,now,fir,sec); } } return res; } ll solve(int now){ vector<int> cs=cent.build(now); int C=cs[0]; cent.kill(C); twoCol.clear();two.clear();one.clear(); oneCol=0; ll res=0; for(auto e:g[C]){ if(cent.alive(e.to)==false) continue; tcol.clear();tmp.clear();omp.clear(); ocol=0; res+=calc(e.to,C,e.col,-1); } for(auto e:g[C]){ if(cent.alive(e.to)==false) continue; res+=solve(e.to); } return res; } int main(){ cin.tie(0); ios::sync_with_stdio(false); cin>>N>>K; cent.initialize(N); g.resize(N); rep(i,N-1){ int a,b,c; cin>>a>>b>>c; a--;b--; g[a].push_back({b,c}); g[b].push_back({a,c}); cent.add_edge(a,b); } cout<<solve(0)<<endl; return 0; }