def min_window_cost(values, m):
    """Return the minimum cost of choosing ``m`` of ``values``.

    The cost of a selection is the sum of squared differences between
    consecutive chosen values once they are sorted. After sorting the input,
    the optimum is always a window of ``m`` consecutive elements, so only
    those windows are examined. Prefix sums make each window cost O(1).

    Args:
        values: Iterable of integers (may be negative).
        m: Number of values to choose.

    Returns:
        The minimum window cost. It is 0 when ``m <= 1``, since fewer than
        two chosen values have no gaps.

    Raises:
        ValueError: If ``m`` exceeds the number of values (no window exists).
    """
    a = sorted(values)
    n = len(a)
    if m > n:
        raise ValueError(f"cannot choose {m} of {n} values")
    if m <= 1:
        return 0
    # prefix[k] = sum of squared gaps between consecutive elements of a[0..k].
    prefix = [0]
    for lo, hi in zip(a, a[1:]):
        prefix.append(prefix[-1] + (hi - lo) ** 2)
    # The window a[i..i+m-1] costs prefix[i+m-1] - prefix[i].
    # Taking min() over the real windows avoids a fragile "infinity" sentinel.
    return min(prefix[i + m - 1] - prefix[i] for i in range(n - m + 1))


def main():
    """Read ``N M`` and the ``N`` values from stdin; print the minimum cost."""
    _n, m = map(int, input().split())
    values = list(map(int, input().split()))
    print(min_window_cost(values, m))


if __name__ == "__main__":
    main()