import sys


def solve(n, m, a):
    """Return the sum of squares of the ``m - 1`` smallest adjacent gaps.

    The values are sorted, the differences between neighbouring sorted
    values are taken, and the squares of the ``m - 1`` smallest of those
    differences are added together.

    Args:
        n: Number of values to use. Only the first ``n`` values of the
            sorted input are considered; presumably ``len(a) == n`` per
            the input format (TODO confirm).
        m: One more than the number of gaps to sum. ``m <= 1`` yields 0.
            Assumes ``m <= n``. If ``m - 1`` exceeds the number of gaps,
            every gap is summed rather than raising.
        a: The integer values, in any order.

    Returns:
        The sum of the squared ``m - 1`` smallest gaps.
    """
    values = sorted(a)[:n]
    # Neighbouring differences of a sorted list are all >= 0, so ordering
    # the gaps by value also orders their squares.
    gaps = sorted(hi - lo for lo, hi in zip(values, values[1:]))
    # Clamp at 0: a negative slice bound would drop gaps from the end
    # instead of selecting none.
    take = max(m - 1, 0)
    return sum(g * g for g in gaps[:take])


def main():
    """Read ``n m`` and then the values from stdin; print ``solve``'s result."""
    readline = sys.stdin.readline
    n, m = map(int, readline().split())
    a = list(map(int, readline().split()))
    print(solve(n, m, a))


if __name__ == "__main__":
    main()