def min_squared_gap_sum(values: list, m: int) -> int:
    """Return the cheapest run of *m* neighbouring values in sorted order.

    The values are sorted, the squared difference between each pair of
    adjacent sorted values is computed, and the smallest sum of ``m - 1``
    consecutive squared differences (i.e. the cost of a run of ``m``
    consecutive sorted values) is returned.

    Args:
        values: The integers to choose from. The list is not modified.
        m: Length of the run of sorted values; must satisfy
            ``1 <= m <= len(values)``.

    Returns:
        The minimum sum of squared adjacent differences over all runs of
        ``m`` consecutive sorted values. This is 0 when ``m == 1``.

    Raises:
        ValueError: If ``m`` is outside ``1..len(values)``. Without this
            check an oversized ``m`` would fail with an opaque IndexError
            and a non-positive ``m`` would index from the wrong end of the
            list.
    """
    if not 1 <= m <= len(values):
        raise ValueError(
            f"m must be between 1 and {len(values)}, got {m}"
        )

    ordered = sorted(values)
    gaps = [(hi - lo) ** 2 for lo, hi in zip(ordered, ordered[1:])]

    # A run of m values spans m - 1 gaps.
    width = m - 1
    window = sum(gaps[:width])
    best = window

    # Slide the window one gap at a time: the gap at the new right edge
    # enters, the gap at the old left edge leaves.
    for entering, leaving in zip(gaps[width:], gaps):
        window += entering - leaving
        best = min(best, window)

    return best


def main() -> None:
    """Read ``N M`` and the value list from stdin and print the answer."""
    # The first number (the element count) is read but not needed: the
    # length of the value list is used instead.
    _count, m = map(int, input().split())
    values = list(map(int, input().split()))
    print(min_squared_gap_sum(values, m))


if __name__ == "__main__":
    main()