"""Sum the squares of the ``m - 1`` smallest gaps between sorted input values.

Input (stdin):
    line 1: two integers ``n m`` (``n`` is read but not otherwise used)
    line 2: whitespace-separated integers

Output (stdout): a single integer.
"""
from typing import Iterable


def smallest_gap_squares(values: Iterable[int], m: int) -> int:
    """Return the sum of squares of the ``m - 1`` smallest adjacent gaps.

    The values are sorted, the differences between neighbouring values are
    computed, and the ``m - 1`` smallest of those differences are squared
    and summed.

    Args:
        values: The integers to process, in any order.
        m: One more than the number of gaps to take. If fewer than
            ``m - 1`` gaps exist, all of them are used.

    Returns:
        The sum of the selected squared gaps; ``0`` when ``m <= 1`` or when
        there are fewer than two values.
    """
    ordered = sorted(values)
    gaps = sorted(hi - lo for lo, hi in zip(ordered, ordered[1:]))
    # Clamp to zero: a negative slice stop would count from the end of the
    # list and silently select "all but the largest gaps" instead of nothing.
    take = max(m - 1, 0)
    return sum(gap * gap for gap in gaps[:take])


def main() -> None:
    """Read the two input lines from stdin and print the answer."""
    # The first number (the declared element count) is not needed: the second
    # line is consumed in full regardless.
    _, m = map(int, input().split())
    values = map(int, input().split())
    print(smallest_gap_squares(values, m))


if __name__ == "__main__":
    main()