import java.util.*;

/**
 * Reads {@code n} and {@code k}, then {@code n} integers, and prints the minimum possible sum of
 * (max - min) over all ways to split the sorted values into exactly {@code k} non-empty
 * contiguous groups. If no such split exists (more groups than values, or {@code k == 0} with
 * values present) it prints {@code Long.MAX_VALUE / 100}.
 */
public class Main {
    /** Cost reported for an infeasible split; kept small enough that adding costs cannot overflow. */
    static final long INF = Long.MAX_VALUE / 100;

    /**
     * Marks a memo cell that has not been computed yet. Computed costs are sums of
     * {@code arr[idx] - arr[i]} with {@code i <= idx} over a sorted array, so they are never
     * negative (assuming the differences themselves do not overflow). Using 0 as the marker, as
     * before, made every zero-cost state recompute on each visit, which is exponential on inputs
     * with many equal values.
     */
    private static final long UNSET = -1;

    /** Input values, sorted ascending before the DP runs. */
    static long[] arr;

    /** Memo table: {@code dp[idx][remain]} caches {@code dfw(idx, remain)}, or {@code UNSET}. */
    static long[][] dp;

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int k = sc.nextInt();
        long[] values = new long[n];
        for (int i = 0; i < n; i++) {
            values[i] = sc.nextLong();
        }
        System.out.println(solve(values, k));
    }

    /**
     * Computes the minimum total group range for splitting {@code values} into exactly {@code k}
     * groups of consecutive sorted values.
     *
     * @param values the input values; not modified (a sorted copy is stored in {@link #arr})
     * @param k number of groups
     * @return the minimum cost, or {@link #INF} if the split is impossible
     */
    static long solve(long[] values, int k) {
        arr = values.clone();
        Arrays.sort(arr);
        int n = arr.length;
        dp = new long[n][k];
        for (long[] row : dp) {
            Arrays.fill(row, UNSET);
        }
        return dfw(n - 1, k - 1);
    }

    /**
     * Minimum total cost of splitting the sorted prefix {@code arr[0..idx]} into exactly
     * {@code remain + 1} non-empty contiguous groups, where a group's cost is its last element
     * minus its first (its max minus min, since {@code arr} is sorted).
     *
     * @param idx index of the last element of the prefix; {@code idx < 0} means an empty prefix,
     *     which costs 0
     * @param remain number of groups still to form, minus one
     * @return the minimum cost, or {@link #INF} when the prefix cannot be split that way
     */
    static long dfw(int idx, int remain) {
        if (idx < 0) {
            return 0;
        }
        if (remain < 0) {
            // Elements remain but no groups are left to hold them.
            return INF;
        }
        if (dp[idx][remain] != UNSET) {
            return dp[idx][remain];
        }
        long target = arr[idx];
        long min = INF;
        // The last group is arr[i..idx]; i >= remain leaves at least `remain` elements in
        // arr[0..i-1] for the remaining groups.
        for (int i = idx; i >= remain; i--) {
            min = Math.min(min, dfw(i - 1, remain - 1) + target - arr[i]);
        }
        dp[idx][remain] = min;
        return min;
    }
}