def min_window_cost(values, m):
    """Return the minimum sum of squared consecutive gaps over m chosen values.

    The values are sorted and every run of ``m`` adjacent sorted values is
    scored by the sum of squared differences between neighbours.  The
    smallest such score is returned.  Restricting to adjacent runs is
    optimal: a non-adjacent choice has gaps that are sums of several
    non-negative sorted differences, and (a + b) ** 2 >= a ** 2 + b ** 2, so
    the adjacent run starting at the same element never scores worse.

    Args:
        values: the candidate integers (not modified).
        m: how many values to choose.

    Returns:
        The minimal cost; 0 when fewer than two values are chosen.

    Raises:
        ValueError: if ``m`` is negative or larger than ``len(values)``.
    """
    ordered = sorted(values)  # sorted() copy: do not mutate the caller's list
    if m < 0 or m > len(ordered):
        raise ValueError(f"m must be in [0, {len(ordered)}], got {m}")
    if m <= 1:
        # Zero or one chosen value has no gaps. Handling this here also
        # avoids the negative-index wrap the old window loop hit for m == 0.
        return 0

    squared_gaps = [(b - a) ** 2 for a, b in zip(ordered, ordered[1:])]
    width = m - 1  # m chosen values span m - 1 gaps

    window = sum(squared_gaps[:width])
    best = window
    # Slide the fixed-width window one gap to the right at a time.
    for i in range(width, len(squared_gaps)):
        window += squared_gaps[i] - squared_gaps[i - width]
        best = min(best, window)
    return best


def main():
    """Read "N M" then the N values from stdin and print the minimal cost."""
    _n, m = map(int, input().split())
    values = list(map(int, input().split()))
    print(min_window_cost(values, m))


if __name__ == "__main__":
    main()