import java.io.InputStream;
import java.io.PrintWriter;
import java.lang.reflect.Array;
import java.math.BigDecimal;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Scanner;
import java.util.Set;
import java.util.Stack;
import java.util.TreeMap;
import java.util.TreeSet;

import static java.util.Comparator.*;

/**
 * Program entry point: wires standard input and standard output into a
 * {@link Solver}, runs it once, and closes the output writer so that the
 * buffered answer is flushed.
 */
public class Main {
    public static void main(String[] args) {
        PrintWriter out = new PrintWriter(System.out);
        Solver solver = new Solver(System.in, out);
        solver.solve();
        // PrintWriter buffers; closing it is what actually flushes the answer.
        out.close();
    }
}

/**
 * Reads one problem instance from an input stream and prints a single number.
 *
 * <p>Input format, as consumed by {@link #solve()}: two whitespace-separated
 * integers {@code N} and {@code K}, followed by {@code N} integers that fit in
 * a {@code long}.
 */
class Solver {
    Scanner sc;
    PrintWriter out;

    /**
     * @param in  stream the whitespace-separated input tokens are read from
     * @param out writer the answer is printed to; it is not flushed or closed here
     */
    public Solver(InputStream in, PrintWriter out) {
        sc = new Scanner(in);
        this.out = out;
    }

    // ==================================================================

    /**
     * Sorts the {@code N} input values, takes the {@code N - 1} differences
     * between neighbours in sorted order, and prints the sum of the
     * {@code N - K} smallest of those differences followed by a line break.
     *
     * <p>When {@code K >= N} nothing is summed and {@code 0} is printed. An
     * empty instance ({@code N == 0}) also prints {@code 0}.
     *
     * <p>NOTE(review): this presumably is the minimum total (max - min) over a
     * split of the values into {@code K} groups, obtained by dropping the
     * {@code K - 1} widest gaps - confirm against the problem statement.
     */
    public void solve() {
        int n = Integer.parseInt(sc.next());
        int k = Integer.parseInt(sc.next());

        long[] values = new long[n];
        for (int i = 0; i < n; i++) {
            values[i] = Long.parseLong(sc.next());
        }
        Arrays.sort(values);

        // n values have n - 1 neighbouring gaps; clamp at 0 so that n == 0
        // does not request a negative array size.
        long[] gaps = new long[Math.max(0, n - 1)];
        for (int i = 1; i < n; i++) {
            gaps[i - 1] = values[i] - values[i - 1];
        }
        Arrays.sort(gaps);

        // (n - 1) gaps minus the (k - 1) largest ones leaves n - k to add up.
        int kept = n - k;
        long total = 0;
        for (int i = 0; i < kept; i++) {
            total += gaps[i];
        }
        out.println(total);
    }
}