import java.util.*;

/**
 * Reads two integers {@code n} and {@code k} followed by {@code n} integer values from standard
 * input, and prints the total distance between consecutive sorted values after discarding the
 * {@code k - 1} largest of those distances.
 *
 * <p>Equivalently, this is the smallest possible sum of (max - min) over the groups when the
 * values are split into at most {@code k} groups.
 */
public class Main {

    /**
     * Program entry point: parses the input, computes the result and prints it on its own line.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        // Deliberately not closed: closing the Scanner would also close System.in.
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int k = sc.nextInt();
        long[] values = new long[n];
        for (int i = 0; i < n; i++) {
            values[i] = sc.nextLong();
        }
        System.out.println(remainingSpan(values, k));
    }

    /**
     * Sums the gaps between neighbouring values in sorted order, then removes the
     * {@code k - 1} largest gaps from that sum.
     *
     * @param values the input values, in any order; the array is not modified
     * @param k      one more than the number of gaps to discard
     * @return the sum of the gaps that are kept; {@code 0} when there are fewer than two values
     *         or when {@code k - 1} is at least the number of gaps
     */
    private static long remainingSpan(long[] values, int k) {
        int n = values.length;
        if (n == 0) {
            // No values means no gaps; also avoids a negative array size below.
            return 0L;
        }

        long[] sorted = values.clone();
        Arrays.sort(sorted);

        long[] gaps = new long[n - 1];
        long total = 0L;
        for (int i = 0; i < gaps.length; i++) {
            gaps[i] = sorted[i + 1] - sorted[i];
            total += gaps[i];
        }
        Arrays.sort(gaps);

        // Only n - 1 gaps exist, so never try to discard more than that. The subtraction is done
        // in long so that an extreme k cannot wrap around.
        long discard = Math.min((long) k - 1L, (long) gaps.length);
        for (int i = 0; i < discard; i++) {
            total -= gaps[gaps.length - 1 - i];
        }
        return total;
    }
}