"""Minimise the windowed sum of squared gaps between sorted values.

Input (stdin): ``n m`` followed by ``n`` integers, separated by any whitespace.
Output: the minimum, over every window of ``m`` consecutive values of the
sorted list, of the sum of squared differences between neighbouring values in
that window.
"""
import sys
from itertools import accumulate

# Printed when no window of m values fits in the input (e.g. m > n).
NO_WINDOW = 10**18


def solve(n, m, A):
    """Return the minimum windowed sum of squared adjacent gaps.

    Args:
        n: number of values to consider; the ``n`` smallest of ``A`` are used
            (normally ``A`` holds exactly ``n`` values).
        m: window size, i.e. how many consecutive sorted values form a window.
        A: the values; the list itself is not modified.

    Returns:
        ``min`` over each window of ``m`` consecutive values of ``sorted(A)``
        of ``sum((b - a) ** 2)`` over neighbouring pairs ``a, b`` in the
        window. A window of a single value has no gaps and costs 0. Returns
        ``NO_WINDOW`` when no window fits (``m < 1`` or ``m`` exceeds the
        number of values).
    """
    values = sorted(A)[:n]
    count = len(values)
    if m < 1 or m > count:
        return NO_WINDOW

    # gaps[k] is the squared gap between values[k] and values[k + 1].
    gaps = [(b - a) ** 2 for a, b in zip(values, values[1:])]
    # prefix[k] == sum(gaps[:k]); lets each window sum be read in O(1).
    prefix = [0, *accumulate(gaps)]
    span = m - 1  # a window of m values contains m - 1 gaps
    # Window starting at index j covers gaps[j : j + span]; j ranges over
    # every start for which the window stays inside `values`.
    return min(prefix[j + span] - prefix[j] for j in range(count - span))


def main():
    """Read ``n m`` and the values from stdin and print ``solve``'s answer."""
    data = sys.stdin.buffer.read().split()
    n, m = int(data[0]), int(data[1])
    A = [int(tok) for tok in data[2:2 + n]]
    print(solve(n, m, A))


if __name__ == "__main__":
    main()