def calc(x, values, k):
    """Score a candidate point ``x`` against ``values``.

    The distances ``|x - a|`` for every ``a`` in ``values`` are sorted; the
    ``k`` smallest are added and all remaining ones are subtracted.

    Args:
        x: Candidate integer position.
        values: The input numbers.
        k: How many of the smallest distances count positively. A ``k`` at
            or above ``len(values)`` adds every distance; a negative ``k``
            is treated as 0 (every distance is subtracted).

    Returns:
        ``sum(k smallest distances) - sum(remaining distances)``.
    """
    distances = sorted(abs(x - a) for a in values)
    split = max(k, 0)  # a negative k would otherwise slice from the end
    return sum(distances[:split]) - sum(distances[split:])


def solve(k, values):
    """Return the smaller ``calc`` score of the first and last input values.

    NOTE(review): only ``values[0]`` and ``values[-1]`` are tried, exactly as
    the original script printed. They are the extremes only if the input is
    sorted; the original's search bounds ``A[0]-1`` / ``A[-1]+1`` suggest the
    author assumed so. Confirm this against the problem statement.
    Interior points can score lower for some inputs. For example,
    ``calc(11, [0, 10, 11, 12, 22], 3) == -20``, while both ends score
    ``-13``. Check the problem statement for which ``x`` are allowed.

    Raises:
        IndexError: if ``values`` is empty.
    """
    return min(calc(values[0], values, k), calc(values[-1], values, k))


def main():
    """Read ``n k`` and the array from stdin and print the answer."""
    # n is implied by the length of the second line; assumes it holds
    # exactly n numbers, as the input format suggests.
    _n, k = map(int, input().split())
    values = list(map(int, input().split()))
    print(solve(k, values))


if __name__ == "__main__":
    main()