結果
問題 | No.1002 Twotone |
ユーザー |
![]() |
提出日時 | 2020-02-28 23:06:04 |
言語 | C++14 (gcc 13.3.0 + boost 1.87.0) |
結果 | AC |
実行時間 | 1,460 ms / 5,000 ms |
コード長 | 7,640 bytes |
コンパイル時間 | 2,327 ms |
コンパイル使用メモリ | 160,332 KB |
実行使用メモリ | 59,648 KB |
最終ジャッジ日時 | 2024-10-13 18:59:11 |
合計ジャッジ時間 | 21,287 ms |
ジャッジサーバーID (参考情報) | judge5 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 33 |
ソースコード
// yukicoder No.1002 "Twotone".
//
// Input:  n k, followed by n - 1 lines "a b c" describing an undirected edge
//         a-b (1-indexed) with colour c. The graph is assumed to be a tree.
// Output: the number of unordered vertex pairs {u, v} whose connecting path
//         uses exactly two distinct colours.
//
// Approach:
//   1. Centroid decomposition counts ordered pairs (u, v), u != v, whose path
//      uses AT MOST two colours.
//   2. A per-colour union-find counts ordered pairs whose path uses exactly ONE
//      colour (both endpoints inside the same monochromatic component).
//   3. (1) - (2) counts every wanted unordered pair twice, so it is halved.
//
// All tree traversals are iterative, so path-shaped inputs cannot overflow the
// call stack.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <map>
#include <numeric>
#include <utility>
#include <vector>

namespace {

using ll = long long;

struct Edge {
    int to;
    int col;
};

using Graph = std::vector<std::vector<Edge>>;

// Colours seen on a path that starts at the current centroid:
//   {kNoColor, c}  -> the path uses the single colour c;
//   {a, b}, a < b  -> the path uses exactly the colours a and b.
// NOTE(review): colours are assumed to be positive (the problem's 1..k), so the
// sentinel -1 can never collide with a real colour.
constexpr int kNoColor = -1;
using ColorSet = std::pair<int, int>;
using ColorCounts = std::map<ColorSet, ll>;

// Adds colour `col` to `set`, keeping the two-colour form sorted.
// Returns false, leaving `set` unchanged, when `col` would be a third colour.
bool add_color(ColorSet& set, int col) {
    if (set.first == col || set.second == col) return true;
    if (set.first != kNoColor) return false;  // already two other colours
    set.first = col;
    if (set.first > set.second) std::swap(set.first, set.second);
    return true;
}

// Number of ordered pairs (x, y), drawn WITH repetition from the multiset
// `counts`, whose combined colour set has at most two colours:
//   * single x single: always <= 2 colours (same-colour pairs included);
//   * {a,b} x {a,b}:   2 colours;
//   * {a,b} x {a} or {b} (both orders): 2 colours;
//   * any two distinct two-colour sets: >= 3 colours, not counted.
ll ordered_pairs_within_two_colors(const ColorCounts& counts) {
    // Lookup without inserting (operator[] would mutate the map).
    auto count_of = [&counts](const ColorSet& key) -> ll {
        const auto it = counts.find(key);
        return it == counts.end() ? 0 : it->second;
    };
    ll singles = 0;
    ll result = 0;
    for (const auto& kv : counts) {
        const ColorSet& set = kv.first;
        const ll cnt = kv.second;
        if (set.first == kNoColor) {
            singles += cnt;
        } else {
            result += cnt * (cnt
                             + 2 * count_of(ColorSet(kNoColor, set.first))
                             + 2 * count_of(ColorSet(kNoColor, set.second)));
        }
    }
    return result + singles * singles;
}

// Finds a centroid of the component of non-removed vertices containing
// `start`. `parent` and `subtree` are scratch buffers of size n.
int find_centroid(const Graph& graph, const std::vector<char>& removed, int start,
                  std::vector<int>& parent, std::vector<int>& subtree) {
    // BFS order guarantees every parent precedes its children.
    std::vector<int> order(1, start);
    parent[start] = -1;
    for (std::size_t i = 0; i < order.size(); ++i) {
        const int v = order[i];
        for (const Edge& e : graph[v]) {
            if (e.to == parent[v] || removed[e.to]) continue;
            parent[e.to] = v;
            order.push_back(e.to);
        }
    }
    const int total = static_cast<int>(order.size());
    // Reverse BFS order: all children's subtree sizes are known before the parent's.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = *it;
        subtree[v] = 1;
        int heaviest = 0;
        for (const Edge& e : graph[v]) {
            if (e.to == parent[v] || removed[e.to]) continue;
            subtree[v] += subtree[e.to];
            heaviest = std::max(heaviest, subtree[e.to]);
        }
        if (heaviest <= total / 2 && total - subtree[v] <= total / 2) return v;
    }
    return start;  // unreachable: every tree has a centroid
}

// Records, for every vertex of the branch entered from `centroid` via edge
// `first`, the colour set of its path to `centroid`. Vertices whose path would
// need a third colour are skipped together with everything behind them.
void collect_color_sets(const Graph& graph, const std::vector<char>& removed,
                        int centroid, const Edge& first, ColorCounts& counts) {
    struct Frame {
        int v;
        int parent;
        ColorSet colors;
    };
    std::vector<Frame> stack;
    stack.push_back({first.to, centroid, ColorSet(kNoColor, first.col)});
    while (!stack.empty()) {
        const Frame f = stack.back();
        stack.pop_back();
        ++counts[f.colors];
        for (const Edge& e : graph[f.v]) {
            if (e.to == f.parent || removed[e.to]) continue;
            ColorSet next = f.colors;
            if (!add_color(next, e.col)) continue;  // third colour: prune
            stack.push_back({e.to, f.v, next});
        }
    }
}

// Ordered pairs (u, v), u != v, of the current component whose path passes
// through `centroid` and uses at most two colours.
ll count_through_centroid(const Graph& graph, const std::vector<char>& removed,
                          int centroid) {
    ColorCounts all;
    ll cross = 0;
    for (const Edge& e : graph[centroid]) {
        if (removed[e.to]) continue;
        ColorCounts branch;
        collect_color_sets(graph, removed, centroid, e, branch);
        // Pairs inside one branch do not pass through the centroid.
        cross -= ordered_pairs_within_two_colors(branch);
        for (const auto& kv : branch) all[kv.first] += kv.second;
    }
    cross += ordered_pairs_within_two_colors(all);
    ll reached = 0;  // vertices whose path to the centroid has <= 2 colours
    for (const auto& kv : all) reached += kv.second;
    return cross + 2 * reached;  // + pairs (centroid, v) and (v, centroid)
}

// Ordered pairs (u, v), u != v, whose path uses at most two colours.
ll count_at_most_two_color_ordered_pairs(const Graph& graph) {
    const int n = static_cast<int>(graph.size());
    std::vector<char> removed(n, 0);
    std::vector<int> parent(n), subtree(n);
    std::vector<int> pending;  // one start vertex per unprocessed component
    if (n > 0) pending.push_back(0);
    ll total = 0;
    while (!pending.empty()) {
        const int start = pending.back();
        pending.pop_back();
        const int centroid = find_centroid(graph, removed, start, parent, subtree);
        removed[centroid] = 1;
        total += count_through_centroid(graph, removed, centroid);
        for (const Edge& e : graph[centroid]) {
            if (!removed[e.to]) pending.push_back(e.to);
        }
    }
    return total;
}

// Disjoint-set union (union by size, path halving). Touched vertices can be
// reset individually so that one instance serves every colour.
class DisjointSet {
public:
    explicit DisjointSet(int n) : parent_(n), size_(n, 1) {
        std::iota(parent_.begin(), parent_.end(), 0);
    }

    int find(int x) {
        while (parent_[x] != x) {
            parent_[x] = parent_[parent_[x]];
            x = parent_[x];
        }
        return x;
    }

    void unite(int a, int b) {
        a = find(a);
        b = find(b);
        if (a == b) return;
        if (size_[a] < size_[b]) std::swap(a, b);
        parent_[b] = a;
        size_[a] += size_[b];
    }

    int component_size(int root) const { return size_[root]; }

    void reset(int x) {
        parent_[x] = x;
        size_[x] = 1;
    }

private:
    std::vector<int> parent_;
    std::vector<int> size_;
};

struct InputEdge {
    int a;
    int b;
    int col;
};

// Ordered pairs (u, v), u != v, joined by a path of a single colour: for each
// colour, a component of size s contributes s * (s - 1).
ll count_monochrome_ordered_pairs(int n, std::vector<InputEdge> edges) {
    std::sort(edges.begin(), edges.end(),
              [](const InputEdge& x, const InputEdge& y) { return x.col < y.col; });
    DisjointSet dsu(n);
    std::vector<char> counted(n, 0);
    std::vector<int> touched;
    ll total = 0;
    for (std::size_t lo = 0; lo < edges.size();) {
        std::size_t hi = lo;
        touched.clear();
        while (hi < edges.size() && edges[hi].col == edges[lo].col) {
            dsu.unite(edges[hi].a, edges[hi].b);
            touched.push_back(edges[hi].a);
            touched.push_back(edges[hi].b);
            ++hi;
        }
        for (const int v : touched) {
            const int root = dsu.find(v);
            if (counted[root]) continue;
            counted[root] = 1;
            const ll s = dsu.component_size(root);
            total += s * (s - 1);
        }
        // Undo this colour's unions before the next colour group.
        for (const int v : touched) {
            counted[v] = 0;
            dsu.reset(v);
        }
        lo = hi;
    }
    return total;
}

}  // namespace

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    int k = 0;  // number of colours; edges are grouped by colour value, so unused
    std::cin >> n >> k;

    Graph graph(n);
    std::vector<InputEdge> edges;
    edges.reserve(n > 0 ? n - 1 : 0);
    for (int i = 0; i + 1 < n; ++i) {
        int a = 0;
        int b = 0;
        int c = 0;
        std::cin >> a >> b >> c;
        --a;
        --b;
        graph[a].push_back({b, c});
        graph[b].push_back({a, c});
        edges.push_back({a, b, c});
    }

    const ll at_most_two = count_at_most_two_color_ordered_pairs(graph);
    const ll exactly_one = count_monochrome_ordered_pairs(n, edges);
    std::cout << (at_most_two - exactly_one) / 2 << '\n';
    return 0;
}