def min_window_cost(values, m):
    """Return the minimum sum of squared consecutive gaps over any m sorted values.

    After sorting ``values``, every run of ``m`` consecutive elements is
    scored as the sum of ``(x[k+1] - x[k]) ** 2`` over its ``m - 1``
    adjacent pairs; the smallest score is returned.  Restricting the search
    to consecutive runs is sufficient: squares are superadditive for
    non-negative gaps, so skipping an element never lowers the score.

    Args:
        values: Integers to choose from (order does not matter).
        m: Number of elements to choose.

    Returns:
        The minimum score; 0 when ``m <= 1`` (no adjacent pairs exist).

    Raises:
        ValueError: If ``m`` exceeds the number of values.
    """
    if m > len(values):
        raise ValueError("m must not exceed the number of values")
    if m <= 1:
        # A window of 0 or 1 elements has no adjacent pairs to score.
        return 0

    xs = sorted(values)
    # sq[k] is the squared gap between xs[k] and xs[k + 1].
    sq = [(b - a) ** 2 for a, b in zip(xs, xs[1:])]
    width = m - 1  # a window of m elements spans m - 1 gaps

    window = sum(sq[:width])
    best = window
    # Slide the window one gap at a time: take in sq[i], drop sq[i - width].
    for i in range(width, len(sq)):
        window += sq[i] - sq[i - width]
        best = min(best, window)
    return best


def main():
    """Read ``n m`` and the n values from stdin, print the minimum score."""
    _, m = map(int, input().split())  # n is implied by the second line
    values = list(map(int, input().split()))
    print(min_window_cost(values, m))


if __name__ == "__main__":
    main()