"""Minimum sum of squared gaps over any m consecutive values in sorted order.

Input (stdin): ``n m`` followed by ``n`` integers.
Output (stdout): the minimum, over every run of ``m`` consecutive values in
the sorted list, of the sum of squared differences between neighbours in
that run.
"""
import sys


def min_squared_gap_sum(values, m):
    """Return the smallest squared-gap sum over windows of ``m`` sorted values.

    The values are sorted ascending.  For every run of ``m`` consecutive
    sorted values, the cost is the sum of ``(next - prev) ** 2`` over the
    ``m - 1`` adjacent pairs inside the run; the minimum cost is returned.

    Args:
        values: Iterable of integers (not modified).
        m: Window size, i.e. how many consecutive sorted values to take.

    Returns:
        The minimum window cost; ``0`` when ``m == 1`` (a single value has
        no neighbours).

    Raises:
        ValueError: If ``m`` is not in ``1..len(values)``.  The original
            script silently produced a meaningless number in this case.
    """
    ordered = sorted(values)
    if m < 1 or m > len(ordered):
        raise ValueError(f"m must be in 1..{len(ordered)}, got {m}")

    # A window of m values contains m - 1 adjacent gaps.
    width = m - 1
    if width == 0:
        return 0

    gaps = [(hi - lo) ** 2 for lo, hi in zip(ordered, ordered[1:])]

    # Sliding-window sum over `gaps`: add the entering gap, drop the leaving one.
    window = sum(gaps[:width])
    best = window
    for i in range(width, len(gaps)):
        window += gaps[i] - gaps[i - width]
        if window < best:
            best = window
    return best


def main():
    """Read ``n m`` and ``n`` integers from stdin and print the answer."""
    tokens = sys.stdin.read().split()
    n, m = int(tokens[0]), int(tokens[1])
    values = [int(tok) for tok in tokens[2:2 + n]]
    print(min_squared_gap_sum(values, m))


if __name__ == "__main__":
    main()