n, k = map(int, input().split()) g: list[list[tuple[int, int, int]]] = [[] for _ in range(n)] for eid in range(n - 1): u, v, c = map(int, input().split()) u -= 1 v -= 1 g[u].append((c, eid, v)) g[v].append((c, eid, u)) uf = [-1] * (n - 1) def size(x: int) -> int: return -uf[find(x)] def find(x: int) -> int: if uf[x] < 0: return x uf[x] = find(uf[x]) return uf[x] def merge(x: int, y: int) -> None: x = find(x) y = find(y) if x == y: return if uf[x] > uf[y]: x, y = y, x uf[x] += uf[y] uf[y] = x for i in range(n): g[i].sort() siz = len(g[i]) for j in range(siz - 1): c1, eid1, v1 = g[i][j] c2, eid2, v2 = g[i][j + 1] if c1 == c2: merge(eid1, eid2) ans = 0 for i in range(n): siz = len(g[i]) cmps: list[int] = [] for j in range(siz): if j == 0 or g[i][j][0] != g[i][j - 1][0]: cmps.append(size(g[i][j][1])) ans += (sum(cmps) ** 2 - sum(e ** 2 for e in cmps)) // 2 print(ans)