"""Choose M of N integers so the sum of squared gaps between neighbours is minimal.

Reads ``N M`` on the first line and the N integers on the second line from
stdin, then prints the minimal value of
sum((b_{j+1} - b_j) ** 2) over the chosen values b_1 <= ... <= b_M.
"""


def min_squared_gap_sum(values, m):
    """Return the minimal sum of squared adjacent gaps over ``m`` chosen values.

    Only windows of ``m`` consecutive values in sorted order are examined.
    This is sufficient: every gap of a consecutive window lies inside some
    gap of any other choice spanning the same start, and splitting a gap
    never increases a sum of squares.

    Args:
        values: The candidate integers (any order).
        m: How many values to choose.

    Returns:
        The minimal sum; 0 when ``m`` is 0 or 1 (no gaps exist).

    Raises:
        ValueError: If ``m`` is negative or larger than ``len(values)``.
    """
    ordered = sorted(values)
    if not 0 <= m <= len(ordered):
        raise ValueError("m must be between 0 and the number of values")
    if m <= 1:
        return 0

    # gaps[i] is the squared distance between sorted neighbours i and i+1.
    gaps = [(hi - lo) ** 2 for lo, hi in zip(ordered, ordered[1:])]
    width = m - 1  # a window of m values spans m-1 consecutive gaps

    window = sum(gaps[:width])
    best = window
    for i in range(width, len(gaps)):
        # Slide one step right: gain the new right gap, drop the old left gap.
        window += gaps[i] - gaps[i - width]
        best = min(best, window)
    return best


def main():
    """Read the problem from stdin and print the answer."""
    n, m = map(int, input().split())
    values = list(map(int, input().split()))
    print(min_squared_gap_sum(values[:n], m))


if __name__ == "__main__":
    main()