結果
問題 | No.1002 Twotone |
ユーザー |
![]() |
提出日時 | 2020-03-01 20:56:27 |
言語 | C++14 (gcc 13.3.0 + boost 1.87.0) |
結果 | AC |
実行時間 | 302 ms / 5,000 ms |
コード長 | 5,450 bytes |
コンパイル時間 | 1,998 ms |
コンパイル使用メモリ | 144,444 KB |
実行使用メモリ | 71,044 KB |
最終ジャッジ日時 | 2024-10-13 20:52:19 |
合計ジャッジ時間 | 9,129 ms |
ジャッジサーバーID (参考情報) | judge4 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 33 |
ソースコード
#include <map>
#include <vector>
#include <iostream>
#include <algorithm>
#include <functional>
#include <set>
#include <random>
#include <string>
using namespace std;

// Half-edge of an undirected tree: neighbour `to` and edge colour `col`.
// A default-constructed edge is (-1, -1); solve() leaves the root's parent
// entry in that state.
class edge {
public:
    int to, col;
    edge() : to(-1), col(-1) {}
    edge(int to_, int col_) : to(to_), col(col_) {}
};

// Counts unordered vertex pairs {u, v} whose tree path uses exactly two
// distinct edge colours. This is the quantity solve_easy() computes by brute
// force; random_gen() cross-checks the two.
//
// N : number of vertices, labelled 0..N-1.
// G : adjacency lists of a tree. Colours must be >= 0, because -1 is used
//     internally as a sentinel key.
// Returns the number of such pairs (0 when N <= 0).
//
// Outline: small-to-large merging over a post-order of the tree rooted at 0.
// d[group[v]] maps a colour key to a vertex count for v's subtree, and
// dsum[group[v]] is the sum of those counts. As far as the code shows, keys
// are relative to pcol, the colour of the edge from v to its parent:
//   - key -1 appears to mean "the path to the parent uses only pcol";
//   - key c appears to mean "the path uses pcol and c".
// Each vertex inherits the map of its largest child and merges the others in.
long long solve(int N, const vector<vector<edge> >& G) {
    if (N <= 0) return 0;  // nothing to count; also avoids indexing G[0]

    vector<vector<edge> > child(N);  // child[v] = edges from v down to its children
    vector<edge> par(N);             // par[v] = (parent of v, colour of that edge)
    vector<int> ord;                 // vertices in post-order (children before parent)
    ord.reserve(N);

    // Iterative DFS from vertex 0. It scans neighbours in adjacency order and
    // emits a vertex once all of its children are finished. This matches what
    // a recursive DFS would produce, without recursing to depth ~N on
    // path-shaped trees.
    vector<size_t> next_edge(N, 0);
    vector<int> stk(1, 0);
    while (!stk.empty()) {
        const int pos = stk.back();
        if (next_edge[pos] < G[pos].size()) {
            const edge& e = G[pos][next_edge[pos]++];
            if (e.to == par[pos].to) continue;  // edge back to the parent (root: -1)
            par[e.to] = edge(pos, e.col);
            child[pos].push_back(e);
            stk.push_back(e.to);
        } else {
            ord.push_back(pos);
            stk.pop_back();
        }
    }

    // Count stored under `key` in `m`, or 0 if absent. Never inserts.
    auto count_of = [](const auto& m, const auto& key) -> long long {
        auto it = m.find(key);
        return it == m.end() ? 0 : it->second;
    };

    vector<int> group(N, -1);  // group[v] = index into d of the map holding v's subtree
    vector<map<int, int> > d(N);
    vector<int> dsum(N);
    long long ans = 0;
    for (int i : ord) {
        // Largest child map first: child[i][0] is the heavy child whose map i reuses.
        sort(child[i].begin(), child[i].end(), [&](const edge& e1, const edge& e2) {
            return d[group[e1.to]].size() > d[group[e2.to]].size();
        });
        group[i] = (child[i].empty() ? i : group[child[i][0].to]);

        // s[(a, b)] counts vertices in the light subtrees handled so far, keyed
        // by the colours of their path to i. (-1, c) stands for the single
        // colour c. subs[c] counts how many of them have c in their key, and
        // subs[-1] counts the single-colour ones.
        map<pair<int, int>, int> s;
        map<int, int> subs;
        for (int j = 0; j < (int)child[i].size(); ++j) {
            const edge e = child[i][j];
            // Pairs (i, x) with x in this child's subtree and exactly two colours.
            ans += dsum[group[e.to]] - d[group[e.to]][-1];
            if (j >= 1) {
                // Pairs between this light child and the earlier light children.
                for (const auto& k : d[group[e.to]]) {
                    int va = k.first, vb = e.col;
                    if (va > vb) swap(va, vb);
                    if (va == vb) va = -1;
                    if (va != -1) {
                        // Colours {va, vb}: the partner may use the same pair,
                        // or only va, or only vb.
                        ans += count_of(s, make_pair(va, vb)) * k.second;
                        ans += count_of(s, make_pair(-1, va)) * k.second;
                        ans += count_of(s, make_pair(-1, vb)) * k.second;
                    }
                    else {
                        // Only colour vb. Valid partners are single-colour ones
                        // (subs[-1]) or ones containing vb (subs[vb]), minus the
                        // "only vb" partners, which both terms include.
                        ans += count_of(subs, -1) * k.second;
                        ans += count_of(subs, vb) * k.second;
                        ans -= 2 * count_of(s, make_pair(-1, vb)) * k.second;
                    }
                }
                for (const auto& k : d[group[e.to]]) {
                    int va = k.first, vb = e.col;
                    if (va > vb) swap(va, vb);
                    if (va == vb) va = -1;
                    s[make_pair(va, vb)] += k.second;
                    subs[va] += k.second;
                    subs[vb] += k.second;
                }
            }
        }

        // Pairs between a light-subtree vertex (recorded in s) and a vertex of
        // the heavy map d[group[i]]. That map's keys are still relative to
        // mcol, the heavy edge's colour.
        int mcol = (child[i].empty() ? -1 : child[i][0].col);
        for (const auto& entry : s) {
            int va = entry.first.first, vb = entry.first.second;
            if (va == mcol) va = -1;
            if (vb == mcol) vb = -1;
            if (va > vb) swap(va, vb);
            if (vb == -1) {
                ans += 1LL * (dsum[group[i]] - d[group[i]][-1]) * entry.second;
            }
            else if (va == -1) {
                if (d[group[i]].find(vb) != d[group[i]].end()) {
                    ans += 1LL * d[group[i]][vb] * entry.second;
                }
                ans += 1LL * d[group[i]][-1] * entry.second;
            }
        }

        // Re-key the inherited heavy map from mcol to pcol (i's parent-edge colour).
        // Keys -1 and pcol both collapse into key mcol. Every other key is
        // dropped and removed from dsum.
        int pcol = par[i].col;
        if (mcol != pcol && mcol != -1) {
            int val = 0, valp = 0;
            for (const auto& entry : d[group[i]]) {
                if (entry.first == -1) {
                    val = entry.second;
                }
                else if (entry.first == pcol) {
                    valp = entry.second;
                }
                else {
                    dsum[group[i]] -= entry.second;
                }
            }
            d[group[i]].clear();
            d[group[i]][mcol] += val + valp;
        }

        // Merge the light children into i's map, keyed relative to pcol.
        // Entries whose colours, together with pcol, would exceed two are skipped.
        for (int j = 1; j < (int)child[i].size(); ++j) {
            int col = child[i][j].col;
            for (const auto& k : d[group[child[i][j].to]]) {
                int va = col, vb = k.first;
                if (va == pcol) va = -1;
                if (vb == pcol) vb = -1;
                if (va > vb) swap(va, vb);
                if (va == vb) va = -1;
                if (va != -1) continue;
                dsum[group[i]] += k.second;
                d[group[i]][vb] += k.second;
            }
        }

        // Vertex i itself: its path to the parent is the single pcol edge.
        dsum[group[i]] += 1;
        d[group[i]][-1] += 1;
    }
    return ans;
}

// Brute-force reference for solve(). For every pair j < i it collects the
// colours on the tree path from j to i and counts the pair when exactly two
// distinct colours occur. Each pair costs a full DFS, so this is O(N^3);
// it is only meant for small stress tests.
long long solve_easy(int N, const vector<vector<edge> >& G) {
    long long ans = 0;
    for (int i = 0; i < N; ++i) {
        for (int j = 0; j < i; ++j) {
            set<int> s;
            // Returns true if `i` lies below `pos`; if so, records the colours
            // of the edges on the way down.
            function<bool(int, int)> dfs = [&](int pos, int pre) {
                if (pos == i) return true;
                bool ok = false;
                for (const edge& e : G[pos]) {
                    if (e.to == pre) continue;
                    if (dfs(e.to, pos)) {
                        s.insert(e.col);
                        ok = true;
                    }
                }
                return ok;
            };
            dfs(j, -1);
            if (s.size() == 2) {
                ++ans;
            }
        }
    }
    return ans;
}

mt19937 mt(2003012025);  // fixed seed: stress runs are reproducible

// Uniform random integer in the half-open range [l, r).
int rand_rng(int l, int r) {
    uniform_int_distribution<int> p(l, r - 1);
    return p(mt);
}

// Formats a vector as "[a, b, c]".
string to_string(const vector<int>& arr) {
    int n = arr.size();
    string ans = "[";
    for (int i = 0; i < n; ++i) {
        if (i) ans += ", ";
        ans += to_string(arr[i]);
    }
    ans += "]";
    return ans;
}

// Stress test. Builds `samples` random trees with N vertices and K colours:
// vertex j attaches to a random earlier vertex with a random colour. Prints
// every case where solve() and solve_easy() disagree.
void random_gen() {
    const int samples = 10000;
    int N = 7, K = 3;
    for (int i = 1; i <= samples; ++i) {
        vector<vector<edge> > G(N);
        vector<int> parseq(N, -1), colseq(N, -1);
        for (int j = 1; j < N; ++j) {
            int par = rand_rng(0, j), col = rand_rng(0, K);
            G[j].push_back(edge(par, col));
            G[par].push_back(edge(j, col));
            parseq[j] = par;
            colseq[j] = col;
        }
        long long res1 = solve(N, G);
        long long res2 = solve_easy(N, G);
        if (res1 != res2) {
            cout << "Case #" << i << " / N = " << N << endl;
            cout << "par = " << to_string(parseq) << endl;
            cout << "col = " << to_string(colseq) << endl;
            cout << "Returns: " << res1 << endl;
            cout << "Answer: " << res2 << endl;
            cout << endl;
        }
    }
}

// Reads "N K", then N-1 edges "a b col" (1-indexed), and prints solve(N, G).
int main() {
    // random_gen();  // uncomment to cross-check solve() against solve_easy()
    cin.tie(0);
    ios_base::sync_with_stdio(false);
    int N, K;
    if (!(cin >> N >> K)) return 0;  // no input: avoid using N uninitialised
    vector<vector<edge> > G(N);
    for (int i = 0; i < N - 1; ++i) {
        int a, b, col;
        cin >> a >> b >> col;
        --a, --b, --col;  // convert to 0-based; colours become >= 0
        G[a].push_back(edge(b, col));
        G[b].push_back(edge(a, col));
    }
    long long ans = solve(N, G);
    cout << ans << '\n';
    return 0;
}