"""Count tree paths whose edges use exactly two distinct colors.

Input (stdin): ``N K`` on the first line, then ``N - 1`` edges ``u v c``
(1-indexed endpoints, integer color ``c``).  Output: the number of unordered
pairs of distinct vertices whose connecting path contains exactly two
distinct edge colors.

The count uses centroid decomposition.  Every traversal is an explicit-stack
loop, so deep (path-like) trees need neither a raised recursion limit nor a
larger C stack.
"""
import sys


def _pair_contribution(path_summary):
    """Count vertex pairs whose combined centroid-path colors number exactly two.

    ``path_summary`` maps a frozenset of colors (the colors on the path from the
    current centroid to some vertex) to how many vertices have that color set.
    Two vertices x, y in different centroid subtrees are joined by a path whose
    color set is ``S_x | S_y``.  This function counts the unordered pairs for
    which that union has size exactly 2.  The empty set stands for the centroid
    itself.  Keys with more than two colors are ignored.
    """
    centroid_count = path_summary.get(frozenset(), 0)
    single_color = []   # (color, count) for one-color sets; one entry per color
    pair_by_color = {}  # color -> total count of two-color sets containing it
    result = 0

    for colors, cnt in path_summary.items():
        size = len(colors)
        if size == 1:
            single_color.append((next(iter(colors)), cnt))
        elif size == 2:
            # centroid <-> vertex whose path already has two colors
            result += centroid_count * cnt
            # two vertices with the identical two-color set
            result += cnt * (cnt - 1) // 2
            for color in colors:
                pair_by_color[color] = pair_by_color.get(color, 0) + cnt

    # Two one-color vertices with *different* colors.  Each color appears once
    # in single_color, so sum^2 - sum(sq) counts ordered cross-color pairs twice.
    total_single = sum(cnt for _, cnt in single_color)
    total_single_sq = sum(cnt * cnt for _, cnt in single_color)
    result += (total_single * total_single - total_single_sq) // 2

    # A one-color vertex {a} pairs with any two-color vertex {a, b}.
    for color, cnt in single_color:
        result += cnt * pair_by_color.get(color, 0)
    return result


def count_two_color_paths(n, edges):
    """Return the number of vertex pairs whose path has exactly two edge colors.

    Args:
        n: number of vertices, labelled ``0 .. n - 1``.
        edges: iterable of ``(u, v, color)`` triples with 0-indexed endpoints
            describing a tree on ``n`` vertices.

    Returns:
        The count of unordered pairs ``{x, y}`` (``x != y``) whose connecting
        path uses exactly two distinct colors.
    """
    adj = [[] for _ in range(n)]
    for u, v, c in edges:
        adj[u].append((v, c))
        adj[v].append((u, c))

    removed = [False] * n  # vertices already chosen as centroids
    size = [0] * n         # scratch subtree sizes for centroid search

    def find_centroid(entry):
        """Return a centroid of the live component containing ``entry``."""
        order = []  # pre-order (vertex, parent): descendants follow ancestors
        stack = [(entry, -1)]
        while stack:
            u, p = stack.pop()
            order.append((u, p))
            for v, _ in adj[u]:
                if v != p and not removed[v]:
                    stack.append((v, u))
        for u, _ in order:
            size[u] = 1
        # Reverse pre-order finishes every subtree before its parent.
        for u, p in reversed(order):
            if p != -1:
                size[p] += size[u]

        comp_size = len(order)
        u, p = entry, -1
        while True:
            for v, _ in adj[u]:
                if v != p and not removed[v] and size[v] * 2 > comp_size:
                    u, p = v, u  # step into the (unique) oversized child
                    break
            else:
                return u

    def collect(start, parent, start_colors):
        """Map color-set -> vertex count for live vertices reachable from ``start``.

        The walk never steps back to ``parent``.  Each vertex's key is
        ``start_colors`` united with the colors on the path from ``start``.
        Branches are pruned as soon as that set would exceed two colors; the
        set only grows along a path, so nothing countable is lost.
        """
        counts = {}
        stack = [(start, parent, start_colors)]
        while stack:
            u, p, colors = stack.pop()
            counts[colors] = counts.get(colors, 0) + 1
            for v, c in adj[u]:
                if v == p or removed[v]:
                    continue
                if c in colors:
                    stack.append((v, u, colors))
                elif len(colors) < 2:
                    stack.append((v, u, colors | {c}))
        return counts

    total = 0
    pending = [0] if n > 0 else []
    while pending:
        centroid = find_centroid(pending.pop())
        total += _pair_contribution(collect(centroid, -1, frozenset()))
        removed[centroid] = True
        for v, c in adj[centroid]:
            if removed[v]:
                continue
            # Pairs inside one child subtree were counted with centroid-based
            # color sets, which do not describe their real path; subtract them.
            # Seeding with {c} makes the keys centroid-relative directly.
            total -= _pair_contribution(collect(v, centroid, frozenset((c,))))
            pending.append(v)
    return total


def solve():
    """Read the tree from stdin and print the two-color path count."""
    tokens = sys.stdin.buffer.read().split()
    n = int(tokens[0])
    # NOTE(review): tokens[1] (K) is read by the input format but not used by
    # this computation -- confirm against the problem statement.
    values = list(map(int, tokens[2:2 + 3 * max(n - 1, 0)]))
    edges = [
        (values[i] - 1, values[i + 1] - 1, values[i + 2])
        for i in range(0, len(values), 3)
    ]
    print(count_two_color_paths(n, edges))


if __name__ == "__main__":
    solve()