import sys
from sys import stdin
from collections import defaultdict
from itertools import combinations

sys.setrecursionlimit(1 << 25)


def _add(parent, weight, x):
    """Register ``x`` as a singleton set if it is not tracked yet."""
    if x not in parent:
        parent[x] = x
        weight[x] = 1


def _find(parent, x):
    """Return the root of ``x``, compressing the path iteratively (no recursion)."""
    root = x
    while parent[root] != root:
        root = parent[root]
    while parent[x] != root:
        parent[x], x = root, parent[x]
    return root


def _union(parent, weight, x, y):
    """Merge the sets of ``x`` and ``y``, hanging the lighter root under the heavier one."""
    rx, ry = _find(parent, x), _find(parent, y)
    if rx == ry:
        return
    if weight[rx] < weight[ry]:
        rx, ry = ry, rx
    parent[ry] = rx
    weight[rx] += weight[ry]


def _single_colour_components(n, edges_by_colour):
    """Give every single-colour connected component a global integer id.

    Returns ``(comp_of, comp_size)``:
    - ``comp_of[u][c]`` is the id of the colour-``c`` component containing
      vertex ``u``; the key exists only if ``u`` has an incident colour-``c`` edge.
    - ``comp_size[cid]`` is the vertex count of component ``cid``.
    """
    comp_of = [{} for _ in range(n + 1)]
    comp_size = []
    for colour, colour_edges in edges_by_colour.items():
        parent, size = {}, {}
        for u, v in colour_edges:
            _add(parent, size, u)
            _add(parent, size, v)
            _union(parent, size, u, v)
        root_id = {}
        for x in parent:
            r = _find(parent, x)
            if r not in root_id:
                root_id[r] = len(comp_size)
                comp_size.append(size[r])
            comp_of[x][colour] = root_id[r]
    return comp_of, comp_size


def _shared_vertices(n, comp_of):
    """Map each colour pair ``(a, b)`` with ``a < b`` to the vertices 1..n incident to both."""
    shared = defaultdict(list)
    for u in range(1, n + 1):
        for a, b in combinations(sorted(comp_of[u]), 2):
            shared[(a, b)].append(u)
    return shared


def _pairs_using_both(a, b, nodes, comp_of, comp_size):
    """Return pairs(a or b) - pairs(a) - pairs(b) for colours ``a`` and ``b``.

    pairs(S) is the number of vertex pairs connected inside the subgraph made
    of the edges whose colour is in S.

    Only components containing a vertex of ``nodes`` (the vertices incident to
    both colours) can merge. Every other component contributes the same amount
    to pairs(a or b) and to pairs(a) / pairs(b), so it cancels and is skipped.
    """
    parent, weight = {}, {}
    for u in nodes:
        ca, cb = comp_of[u][a], comp_of[u][b]
        _add(parent, weight, ca)
        _add(parent, weight, cb)
        _union(parent, weight, ca, cb)

    merged_size = defaultdict(int)
    inner_pairs = defaultdict(int)
    for cid in parent:
        root = _find(parent, cid)
        s = comp_size[cid]
        merged_size[root] += s
        inner_pairs[root] += s * (s - 1) // 2
    # A shared vertex lies in one a-component and one b-component of the same
    # merged group, so the sum above counted it twice.
    for u in nodes:
        merged_size[_find(parent, comp_of[u][a])] -= 1

    return sum(m * (m - 1) // 2 - inner_pairs[r] for r, m in merged_size.items())


def main():
    """Read an edge-coloured graph from stdin and print one integer.

    Input is whitespace-separated: ``N K`` followed by N-1 edges ``u v c``.
    ``K`` is parsed but not otherwise used.

    For every colour pair (a, b) that meets at some vertex 1..N, the program
    sums pairs(a or b) - pairs(a) - pairs(b). Here pairs(S) counts vertex
    pairs connected using only edges whose colour is in S.

    Assuming the edges form a tree, each vertex pair has a unique path.
    Such a path uses exactly the colours {a, b} only if it switches colour at
    a vertex where a and b meet. So the printed total is the number of vertex
    pairs whose path uses exactly two distinct colours.

    The single-colour components are built once and reused. This avoids
    re-running union-find over every edge of both colours for every pair,
    which is quadratic when one colour meets many others.
    """
    data = sys.stdin.read().split()
    n, _k = int(data[0]), int(data[1])  # K is not needed by the algorithm

    edges_by_colour = defaultdict(list)
    pos = 2
    for _ in range(n - 1):
        u, v, c = int(data[pos]), int(data[pos + 1]), int(data[pos + 2])
        pos += 3
        edges_by_colour[c].append((u, v))

    comp_of, comp_size = _single_colour_components(n, edges_by_colour)
    shared = _shared_vertices(n, comp_of)

    valid = 0
    for (a, b), nodes in shared.items():
        valid += _pairs_using_both(a, b, nodes, comp_of, comp_size)
    print(valid)


if __name__ == '__main__':
    main()