// Edge-colored tree, pair counting.
//
// Input (see main): N K, then N-1 lines "a b col" with 1-based vertices and
// colors. Output: solve(N, G).
//
// solve_easy is a brute-force reference that counts vertex pairs whose tree
// path uses exactly two distinct edge colors; random_gen compares solve
// against it on random small trees, so solve is intended to compute the same
// count.

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <map>
#include <random>
#include <set>
#include <string>
#include <utility>
#include <vector>
using namespace std;

// Adjacency-list entry: neighbour `to` reached via an edge of color `col`.
// The default-constructed value (-1, -1) means "no edge"; the root's parent
// edge keeps that value in solve().
class edge {
public:
    int to, col;
    edge() : to(-1), col(-1) {}
    edge(int to_, int col_) : to(to_), col(col_) {}
};

// Value stored under `key` in `m`, or 0 when the key is absent.
// Unlike operator[], this never inserts into the map.
template <class Map, class Key>
static int count_or_zero(const Map& m, const Key& key) {
    auto it = m.find(key);
    return it == m.end() ? 0 : it->second;
}

// Fast counter, cross-checked against solve_easy by random_gen.
//
// Roots the tree at vertex 0 and processes vertices in post-order. Each
// vertex reuses the map of its child with the largest map (`group[i]` names
// whose `d`/`dsum` slot holds i's data), and the other children's maps are
// merged into it. This is small-to-large merging.
// NOTE(review): the exact meaning of the keys of d[g] (-1 or a color) and of
// dsum[g] is not documented by the original; presumably they count
// descendants by the color pattern of their path to the current vertex.
// Confirm before relying on it.
//
// @param N  number of vertices; G must have N adjacency lists.
// @param G  undirected tree; each edge appears in both endpoints' lists.
// @return   the pair count (0 for N <= 0).
long long solve(int N, const vector<vector<edge>>& G) {
    if (N <= 0) return 0;  // original indexed G[0] out of bounds here

    // Root at 0: child lists, parent edge of every vertex, post-order `ord`.
    // An explicit stack replaces the original recursive DFS, whose depth
    // equals the tree height and could overflow the call stack on long
    // paths. Visit order, `par`, `child` order and `ord` are unchanged.
    vector<vector<edge>> child(N);
    vector<edge> par(N);
    vector<int> ord;
    ord.reserve(N);
    {
        struct Frame {
            int pos, pre;
            size_t next;  // next index into G[pos] to examine
        };
        vector<Frame> stk;
        stk.push_back(Frame{0, -1, 0});
        while (!stk.empty()) {
            const int pos = stk.back().pos;
            const int pre = stk.back().pre;
            const size_t idx = stk.back().next;
            if (idx == G[pos].size()) {
                ord.push_back(pos);  // all children finished: post-order
                stk.pop_back();
                continue;
            }
            ++stk.back().next;
            const edge& e = G[pos][idx];
            if (e.to == pre) continue;
            par[e.to] = edge(pos, e.col);
            child[pos].push_back(e);
            stk.push_back(Frame{e.to, pos, 0});
        }
    }

    vector<int> group(N, -1);
    vector<map<int, int>> d(N);
    vector<int> dsum(N);
    long long ans = 0;

    for (int i : ord) {
        // Largest map first: child[i][0] is the "heavy" child whose storage
        // i inherits.
        sort(child[i].begin(), child[i].end(), [&](const edge& e1, const edge& e2) {
            return d[group[e1.to]].size() > d[group[e2.to]].size();
        });
        group[i] = child[i].empty() ? i : group[child[i][0].to];
        // Stable references: `d` and `dsum` are never resized.
        map<int, int>& di = d[group[i]];
        int& dsumi = dsum[group[i]];

        // Combine children one at a time. Child j is paired against the
        // children 0..j-1 already folded into `s` / `subs`.
        map<pair<int, int>, int> s;
        map<int, int> subs;
        for (size_t j = 0; j < child[i].size(); ++j) {
            const edge& e = child[i][j];
            map<int, int>& dc = d[group[e.to]];
            // operator[] is deliberate: it may insert key -1, and map sizes
            // feed the heavy-child sort above.
            ans += dsum[group[e.to]] - dc[-1];
            if (j >= 1) {
                for (const auto& k : dc) {
                    int va = k.first, vb = e.col;
                    if (va > vb) swap(va, vb);
                    if (va == vb) va = -1;
                    if (va != -1) {
                        ans += 1LL * count_or_zero(s, make_pair(va, vb)) * k.second;
                        ans += 1LL * count_or_zero(s, make_pair(-1, va)) * k.second;
                        ans += 1LL * count_or_zero(s, make_pair(-1, vb)) * k.second;
                    } else {
                        ans += 1LL * count_or_zero(subs, -1) * k.second;
                        ans += 1LL * count_or_zero(subs, vb) * k.second;
                        ans -= 2LL * count_or_zero(s, make_pair(-1, vb)) * k.second;
                    }
                }
            }
            for (const auto& k : dc) {
                int va = k.first, vb = e.col;
                if (va > vb) swap(va, vb);
                if (va == vb) va = -1;
                s[make_pair(va, vb)] += k.second;
                subs[va] += k.second;
                subs[vb] += k.second;
            }
        }

        // Pair the collected entries with the heavy child's data, with colors
        // equal to the heavy edge's color folded to -1.
        const int mcol = child[i].empty() ? -1 : child[i][0].col;
        for (const auto& entry : s) {
            int va = entry.first.first, vb = entry.first.second;
            if (va == mcol) va = -1;
            if (vb == mcol) vb = -1;
            if (va > vb) swap(va, vb);
            if (vb == -1) {
                ans += 1LL * (dsumi - di[-1]) * entry.second;
            } else if (va == -1) {
                ans += 1LL * count_or_zero(di, vb) * entry.second;
                ans += 1LL * di[-1] * entry.second;
            }
        }

        // Re-key the inherited map relative to i's parent edge color (pcol is
        // -1 for the root): keep only the -1 and pcol entries, merged under
        // mcol.
        const int pcol = par[i].col;
        if (mcol != pcol && mcol != -1) {
            int val = 0, valp = 0;
            for (const auto& k : di) {
                if (k.first == -1) {
                    val = k.second;
                } else if (k.first == pcol) {
                    valp = k.second;
                } else {
                    dsumi -= k.second;
                }
            }
            di.clear();
            di[mcol] += val + valp;
        }

        // Merge the light children's maps into i's map (small-to-large).
        for (size_t j = 1; j < child[i].size(); ++j) {
            const int col = child[i][j].col;
            for (const auto& k : d[group[child[i][j].to]]) {
                int va = col, vb = k.first;
                if (va == pcol) va = -1;
                if (vb == pcol) vb = -1;
                if (va > vb) swap(va, vb);
                if (va == vb) va = -1;
                if (va != -1) continue;
                dsumi += k.second;
                di[vb] += k.second;
            }
        }

        // Account for vertex i itself.
        dsumi += 1;
        di[-1] += 1;
    }
    return ans;
}

// Brute-force reference: for every pair j < i, collect the colors on the
// path from j to i and count the pairs that see exactly two distinct colors.
// Roughly O(N^3); intended only for the small trees built by random_gen.
long long solve_easy(int N, const vector<vector<edge>>& G) {
    long long ans = 0;
    for (int i = 0; i < N; ++i) {
        for (int j = 0; j < i; ++j) {
            set<int> s;
            // True iff `i` lies in the part of the tree reached from `pos`
            // without stepping back to `pre`. Colors of the edges leading
            // toward `i` are recorded in `s`.
            function<bool(int, int)> dfs = [&](int pos, int pre) {
                if (pos == i) return true;
                bool ok = false;
                for (const edge& e : G[pos]) {
                    if (e.to == pre) continue;
                    if (dfs(e.to, pos)) {
                        s.insert(e.col);
                        ok = true;
                    }
                }
                return ok;
            };
            dfs(j, -1);
            if (s.size() == 2) ++ans;
        }
    }
    return ans;
}

mt19937 mt(2003012025);

// Uniform random integer in the half-open range [l, r).
int rand_rng(int l, int r) {
    uniform_int_distribution<int> p(l, r - 1);
    return p(mt);
}

// Formats `arr` as "[a, b, c]".
string to_string(const vector<int>& arr) {
    int n = arr.size();
    string ans = "[";
    for (int i = 0; i < n; ++i) {
        if (i) ans += ", ";
        ans += to_string(arr[i]);
    }
    ans += "]";
    return ans;
}

// Stress test: builds random trees (vertex j attaches to a random earlier
// vertex with a color in [0, K)) and prints every case where solve and
// solve_easy disagree.
void random_gen() {
    const int samples = 10000;
    int N = 7, K = 3;
    for (int i = 1; i <= samples; ++i) {
        vector<vector<edge>> G(N);
        vector<int> parseq(N, -1), colseq(N, -1);
        for (int j = 1; j < N; ++j) {
            int par = rand_rng(0, j), col = rand_rng(0, K);
            G[j].push_back(edge(par, col));
            G[par].push_back(edge(j, col));
            parseq[j] = par;
            colseq[j] = col;
        }
        long long res1 = solve(N, G);
        long long res2 = solve_easy(N, G);
        if (res1 != res2) {
            cout << "Case #" << i << " / N = " << N << endl;
            cout << "par = " << to_string(parseq) << endl;
            cout << "col = " << to_string(colseq) << endl;
            cout << "Returns: " << res1 << endl;
            cout << "Answer: " << res2 << endl;
            cout << endl;
        }
    }
}

int main() {
    // random_gen();
    cin.tie(nullptr);
    ios_base::sync_with_stdio(false);
    // K is the number of colors (random_gen draws colors from [0, K)); it is
    // read only to consume the input.
    int N, K;
    if (!(cin >> N >> K) || N < 0) return 1;  // malformed header
    vector<vector<edge>> G(N);
    for (int i = 0; i < N - 1; ++i) {
        int a, b, col;
        cin >> a >> b >> col;
        --a, --b, --col;  // input is 1-based
        G[a].push_back(edge(b, col));
        G[b].push_back(edge(a, col));
    }
    long long ans = solve(N, G);
    cout << ans << '\n';
    return 0;
}