"""Read ``n k`` and ``n`` integers from stdin and print a chosen-minus-rest score."""

from itertools import accumulate


def solve(n, k, a):
    """Return the "chosen minus unchosen" absolute-distance score for ``a``.

    The values are sorted (as a copy; the caller's list is not modified) and
    scored against an anchor value ``x``:

    * ``k == n``: every value is chosen, and the result is
      ``sum(|v - x| for v in a)`` with ``x`` a median of ``a``.
    * otherwise: the smaller of two candidates, each scored as
      ``sum(|v - x| for chosen v) - sum(|v - x| for the other v)``:

      - ``x = min(a)`` with the ``k`` smallest values chosen;
      - ``x = max(a)`` with the ``k`` largest values chosen.

      ``k == 0`` falls out of the same formula (nothing is chosen).

    Args:
        n: Number of values; must equal ``len(a)`` and be at least 1.
        k: How many values are chosen; must satisfy ``0 <= k <= n``.
        a: The integer values.

    Returns:
        The score as an ``int``.

    Raises:
        ValueError: if ``n < 1``, ``len(a) != n`` or ``k`` is outside
            ``[0, n]`` (a negative ``k`` would otherwise wrap Python's
            negative indexing and yield a meaningless result).
    """
    if n < 1 or len(a) != n:
        raise ValueError(f"expected n >= 1 values, got n={n}, len(a)={len(a)}")
    if not 0 <= k <= n:
        raise ValueError(f"k must satisfy 0 <= k <= n, got k={k}, n={n}")

    values = sorted(a)

    if k == n:
        # sum(|v - x|) is minimised at a median. For even n every x between
        # the two middle values gives the same total, so values[n // 2]
        # covers both the odd and the even case.
        median = values[n // 2]
        return sum(abs(v - median) for v in values)

    prefix = [0, *accumulate(values)]  # prefix[i] == sum(values[:i])
    total = prefix[n]
    low, high = values[0], values[-1]

    # Anchor at the minimum and choose the k smallest values.
    chosen = prefix[k] - low * k
    rest = (total - prefix[k]) - low * (n - k)
    left_val = chosen - rest

    # Anchor at the maximum and choose the k largest values.
    chosen = high * k - (total - prefix[n - k])
    rest = high * (n - k) - prefix[n - k]
    right_val = chosen - rest

    # NOTE(review): only the two endpoint anchors are tried here, while the
    # k == n case uses the median. If the anchor may be any element, an
    # interior anchor can be better: a=[0, 10, 10, 10, 20], k=4 scores 10 at
    # both endpoints but 0 at x=10 (choose 0 and the three 10s). Confirm
    # against the problem statement before relying on this branch.
    return min(left_val, right_val)


def main():
    """Read ``n k`` on one line and the values on the next; print the score."""
    n, k = map(int, input().split())
    a = list(map(int, input().split()))
    print(solve(n, k, a))


if __name__ == "__main__":
    main()