def solve(n, m, positions):
    """Return the sum of the ``m - 1`` smallest squared gaps between neighbours.

    The values are sorted ascending and only the first ``n`` are kept. The
    squared difference of every adjacent pair is computed, and the ``m - 1``
    smallest of those squares are summed.

    NOTE(review): if the underlying task is "split the points into m groups
    and minimise the within-group squared gaps", the quantity to sum would be
    the ``n - m`` smallest gaps, not ``m - 1``. This function keeps the
    original script's ``m - 1`` behaviour. Confirm against the problem
    statement.

    Args:
        n: Number of values to consider (the first ``n`` after sorting).
        m: One more than the number of squared gaps to sum; ``m <= 1`` sums
            none and yields 0.
        positions: Integer values; the input sequence is not modified.

    Returns:
        The selected sum as an int.

    Raises:
        ValueError: if ``m - 1`` exceeds the number of available gaps (the
            original script failed here with a bare ``IndexError``).
    """
    ordered = sorted(positions)[:n]
    # Squared distance between each pair of adjacent sorted values.
    squared_gaps = sorted(
        (hi - lo) * (hi - lo) for lo, hi in zip(ordered, ordered[1:])
    )
    count = m - 1
    if count <= 0:
        # The original loop `range(m - 1)` runs zero times here.
        return 0
    if count > len(squared_gaps):
        raise ValueError(
            f"need {count} gaps but only {len(squared_gaps)} are available"
        )
    return sum(squared_gaps[:count])


def main():
    """Read ``n m`` and the values from stdin, then print ``solve``'s result."""
    n, m = map(int, input().split())
    positions = list(map(int, input().split()))
    print(solve(n, m, positions))


if __name__ == "__main__":
    main()