def min_window_cost(values, m):
    """Return the minimum cost of m values taken consecutively from sorted(values).

    The cost of a window is the sum of squared differences between
    neighbouring values in it.  Because (x + y) ** 2 >= x ** 2 + y ** 2,
    skipping an element never lowers the cost, so only windows that are
    contiguous in sorted order need to be examined.

    Args:
        values: Iterable of integers (need not be sorted).
        m: Number of values per window; must satisfy 1 <= m <= len(values).

    Returns:
        The smallest window cost (0 when m == 1).

    Raises:
        ValueError: If m is outside [1, len(values)].
    """
    ordered = sorted(values)
    if not 1 <= m <= len(ordered):
        raise ValueError(f"m must be in [1, {len(ordered)}], got {m}")

    # Squared gap between each pair of neighbours in sorted order.
    gaps = [(hi - lo) ** 2 for lo, hi in zip(ordered, ordered[1:])]

    # A window of m values spans exactly m - 1 consecutive gaps.
    span = m - 1
    window = sum(gaps[:span])
    # Seed with a real window cost instead of an arbitrary "infinity",
    # so arbitrarily large integer costs are handled correctly.
    best = window
    for i in range(span, len(gaps)):
        # Slide right: add the gap entering the window, drop the one leaving.
        window += gaps[i] - gaps[i - span]
        best = min(best, window)
    return best


def main():
    """Read "n m" and a line of integers from stdin; print the minimum window cost."""
    _n, m = map(int, input().split())  # n is read but the second line is used as-is
    values = list(map(int, input().split()))
    print(min_window_cost(values, m))


if __name__ == "__main__":
    main()