def min_window_cost(values, m):
    """Return the smallest sum of squared adjacent gaps over any m sorted items.

    The values are sorted ascending, then every run of ``m`` consecutive
    sorted elements is scored as the sum of ``(next - current) ** 2`` over
    its ``m - 1`` neighbouring pairs.  The minimum score is returned.

    Args:
        values: Iterable of integers; it is not modified.
        m: Number of consecutive sorted elements per window.

    Returns:
        The minimum window score (``0`` when ``m == 1``).

    Raises:
        ValueError: If ``m`` is not in ``1..len(values)``.
    """
    ordered = sorted(values)
    count = len(ordered)
    if not 1 <= m <= count:
        raise ValueError(
            "window size m=%d must be between 1 and the number of values (%d)"
            % (m, count)
        )

    # Squared gap between each pair of neighbours in sorted order.
    gaps = [(hi - lo) ** 2 for lo, hi in zip(ordered, ordered[1:])]

    # A window of m elements spans m - 1 gaps.
    width = m - 1
    running = sum(gaps[:width])
    best = running

    # Slide right by one element: the left-most gap leaves the window and the
    # next gap on the right enters it, so the sum is updated in O(1).
    for leaving, entering in zip(gaps, gaps[width:]):
        running += entering - leaving
        best = min(best, running)
    return best


def main():
    """Read ``N M`` and the list from stdin, then print the minimum score."""
    n, m = map(int, input().split())
    a = list(map(int, input().split()))
    # Keep the original semantics: sort everything that was read, then only
    # consider the first N entries of the sorted list.
    print(min_window_cost(sorted(a)[:n], m))


if __name__ == "__main__":
    main()