def _parse_edges(tokens):
    """Parse ``N M`` followed by ``M`` triples ``a b w``.

    Vertices ``a`` and ``b`` are read as 1-based and converted to 0-based.

    Returns:
        ``(n, edges)`` where ``edges`` is a list of ``(w, a, b)`` tuples.
        Weight comes first so a plain sort orders the edges by weight.

    Raises:
        IndexError: if fewer tokens are supplied than ``M`` requires.
        ValueError: if a token is not an integer.
    """
    n = int(tokens[0])
    m = int(tokens[1])
    edges = []
    for i in range(m):
        base = 2 + 3 * i
        a = int(tokens[base]) - 1
        b = int(tokens[base + 1]) - 1
        w = int(tokens[base + 2])
        edges.append((w, a, b))
    return n, edges


def _bottleneck_pair_sum(n, edges):
    """Sum, over vertex pairs joined by some merge, of the weight that joined them.

    Edges are processed in ascending ``(w, a, b)`` order, Kruskal-style.
    A merge of two components A and B by an edge of weight ``w`` adds
    ``w * |A| * |B|``: every pair (u in A, v in B) first becomes connected
    at that edge. Because edges come in ascending order, ``w`` is the
    smallest possible maximum edge weight over all u-v paths.

    Pairs that never end up in the same component contribute nothing.
    Self-loops and edges inside an existing component are skipped.
    """
    parent = list(range(n))
    size = [1] * n

    def find(u):
        while parent[u] != u:
            parent[u] = parent[parent[u]]  # path halving keeps trees shallow
            u = parent[u]
        return u

    total = 0
    for w, a, b in sorted(edges):
        root_a, root_b = find(a), find(b)
        if root_a == root_b:
            continue
        total += w * size[root_a] * size[root_b]
        # Union by size: attach the smaller tree under the larger one.
        if size[root_a] < size[root_b]:
            root_a, root_b = root_b, root_a
        parent[root_b] = root_a
        size[root_a] += size[root_b]
    return total


def main():
    """Read the graph from stdin and print the bottleneck pair sum."""
    import sys

    tokens = sys.stdin.read().split()
    n, edges = _parse_edges(tokens)
    print(_bottleneck_pair_sum(n, edges))


if __name__ == "__main__":
    main()