"""Pick m values from a list so the sum of squared adjacent gaps is minimal.

Input (stdin): ``n m`` followed by ``n`` integers (any whitespace layout).
Output (stdout): the minimal sum of ``(b[j+1] - b[j]) ** 2`` over a chosen,
sorted group ``b`` of ``m`` values. After sorting, an optimal group is always a
contiguous run of ``m`` values, so a sliding window over the sorted list
suffices.
"""
import sys


def min_window_cost(values, m):
    """Return the minimal sum of squared adjacent gaps over m consecutive sorted values.

    Args:
        values: iterable of integers (need not be sorted).
        m: size of the window / number of values to pick.

    Returns:
        The minimal window cost; 0 when ``m <= 1`` (a window has no gaps).

    Raises:
        ValueError: if ``m`` exceeds the number of values.
    """
    a = sorted(values)
    k = m - 1  # number of adjacent gaps inside a window of m values
    if k <= 0:
        return 0
    if m > len(a):
        raise ValueError(f"m={m} exceeds the number of values ({len(a)})")

    # gaps[i] is the squared gap between a[i] and a[i + 1].
    gaps = [(y - x) ** 2 for x, y in zip(a, a[1:])]

    window = sum(gaps[:k])
    best = window
    # Slide: gap j enters the window while gap j - k leaves it.
    for j in range(k, len(gaps)):
        window += gaps[j] - gaps[j - k]
        best = min(best, window)
    return best


def main():
    """Read ``n m`` and n integers from stdin and print the minimal window cost."""
    data = sys.stdin.read().split()
    n, m = int(data[0]), int(data[1])
    values = [int(tok) for tok in data[2:2 + n]]
    print(min_window_cost(values, m))


if __name__ == "__main__":
    main()