#include <algorithm>
#include <functional>
#include <iostream>
#include <map>
#include <utility>
#include <vector>
using namespace std;

// One adjacency entry of the edge-coloured tree: the vertex the edge leads to
// and the (0-based) colour of that edge. A default-constructed edge uses -1 for
// both fields, which the code below treats as "no edge / no colour".
class edge {
public:
    int to, col;
    edge() : to(-1), col(-1) {}
    edge(int to_, int col_) : to(to_), col(col_) {}
};

// Reads `N K` followed by N-1 lines `a b col` (1-based vertices and colours)
// from stdin, roots the tree at vertex 0, and walks the vertices in post-order.
// Every vertex shares a colour->count map with its largest child ("group"), the
// smaller children are folded into it, and a running total `ans` is accumulated
// along the way and finally printed to stdout.
//
// K is read but never used afterwards. The exact quantity `ans` represents is
// not stated anywhere in this file.
int main() {
    cin.tie(0);
    ios_base::sync_with_stdio(false);

    int N, K;
    cin >> N >> K;

    // Undirected adjacency lists; input is converted to 0-based indices.
    vector<vector<edge> > G(N);
    for (int i = 0; i < N - 1; ++i) {
        int a, b, col;
        cin >> a >> b >> col;
        --a, --b, --col;
        G[a].push_back(edge(b, col));
        G[b].push_back(edge(a, col));
    }

    // child[v]: edges from v to its children; par[v]: edge from v to its parent
    // (default edge with col == -1 for the root); ord: vertices in post-order.
    vector<vector<edge> > child(N);
    vector<edge> par(N);
    vector<int> ord;
    function<void(int, int)> make_order = [&](int pos, int pre) {
        for (const edge& e : G[pos]) {
            if (e.to == pre) continue;
            par[e.to] = edge(pos, e.col);
            child[pos].push_back(e);
            make_order(e.to, pos);
        }
        ord.push_back(pos);
    };
    make_order(0, -1);

    // group[v]: index into d/dsum of the map that currently represents v's subtree.
    // d[g]:     colour -> count, where key -1 is a special "no colour" bucket.
    // dsum[g]:  running sum maintained alongside d[g].
    vector<int> group(N, -1);
    vector<map<int, int> > d(N);
    vector<long long> dsum(N);
    long long ans = 0;

    for (int i : ord) {
        // Put the child with the largest map first so its map can be reused.
        sort(child[i].begin(), child[i].end(), [&](const edge& e1, const edge& e2) {
            return d[group[e1.to]].size() > d[group[e2.to]].size();
        });
        group[i] = (child[i].empty() ? i : group[child[i][0].to]);
        const int child_count = static_cast<int>(child[i].size());

        // s:    (va, vb) colour pair -> count, gathered from the non-first children.
        // subs: single colour -> count, gathered from the same entries.
        map<pair<int, int>, int> s;
        map<int, int> subs;
        for (int j = 0; j < child_count; ++j) {
            const edge& e = child[i][j];
            ans += dsum[group[e.to]] - d[group[e.to]][-1];
            if (j >= 1) {
                // First pass: combine this child's entries with what earlier
                // non-first children already put into s / subs.
                for (const auto& k : d[group[e.to]]) {
                    int va = k.first, vb = e.col;
                    if (va > vb) swap(va, vb);
                    if (va == vb) va = -1;
                    if (va != -1) {
                        if (s.find(make_pair(va, vb)) != s.end()) {
                            ans += s[make_pair(va, vb)];
                        }
                        // NOTE(review): the lookup tests key (-1, va) but reads key
                        // (-1, vb); operator[] may therefore insert a zero entry for
                        // (-1, vb) that the later loop over `s` will visit. Kept
                        // as-is -- confirm whether (-1, va) was intended.
                        if (s.find(make_pair(-1, va)) != s.end()) {
                            ans += s[make_pair(-1, vb)];
                        }
                        if (s.find(make_pair(-1, vb)) != s.end()) {
                            ans += s[make_pair(-1, vb)];
                        }
                    } else {
                        if (subs.find(-1) != subs.end()) {
                            ans += subs[-1];
                        }
                        if (subs.find(vb) != subs.end()) {
                            ans += subs[vb];
                        }
                        if (s.find(make_pair(-1, vb)) != s.end()) {
                            ans -= s[make_pair(-1, vb)] * 2;
                        }
                    }
                }
                // Second pass: register this child's entries for later children.
                for (const auto& k : d[group[e.to]]) {
                    int va = k.first, vb = e.col;
                    if (va > vb) swap(va, vb);
                    if (va == vb) va = -1;
                    s[make_pair(va, vb)] += k.second;
                    subs[va] += k.second;
                    subs[vb] += k.second;
                }
            }
        }

        // Colour of the edge to the first (largest) child, or -1 for a leaf.
        const int mcol = (child[i].empty() ? -1 : child[i][0].col);
        for (const auto& j : s) {
            int va = j.first.first, vb = j.first.second;
            if (va == mcol) va = -1;
            if (vb == mcol) vb = -1;
            if (va > vb) swap(va, vb);
            if (vb == -1) {
                ans += dsum[group[i]] - d[group[i]][-1];
            } else if (va == -1) {
                // Products are widened so they cannot overflow int before being
                // added to the 64-bit total.
                if (d[group[i]].find(vb) != d[group[i]].end()) {
                    ans += static_cast<long long>(d[group[i]][vb]) * j.second;
                }
                ans += static_cast<long long>(d[group[i]][-1]) * j.second;
            }
        }

        // When the first child's edge colour differs from the parent edge colour,
        // collapse the shared map to a single mcol bucket holding the former -1
        // and pcol counts; every other bucket is removed from dsum.
        const int pcol = par[i].col;
        if (mcol != pcol && mcol != -1) {
            int val = 0, valp = 0;
            for (const auto& j : d[group[i]]) {
                if (j.first == -1) {
                    val = j.second;
                } else if (j.first == pcol) {
                    valp = j.second;
                } else {
                    dsum[group[i]] -= j.second;
                }
            }
            d[group[i]].clear();
            d[group[i]][mcol] += val + valp;
        }

        // Fold the remaining children into the shared map.
        for (int j = 1; j < child_count; ++j) {
            const int col = child[i][j].col;
            // NOTE(review): this iterates d[child vertex], whereas every other
            // access goes through d[group[...]]; kept as-is -- confirm intent.
            for (const auto& k : d[child[i][j].to]) {
                int va = col, vb = k.first;
                if (va == pcol) va = -1;
                if (vb == pcol) vb = -1;
                if (va > vb) swap(va, vb);
                if (va == vb) va = -1;
                if (va != -1) continue;
                dsum[group[i]] += k.second;
                d[group[i]][vb] += k.second;
            }
        }

        // Account for vertex i itself in the "no colour" bucket.
        dsum[group[i]] += 1;
        d[group[i]][-1] += 1;
    }

    cout << ans << '\n';
    return 0;
}