def min_total_span(points, k):
    """Return the sum of the smallest len(points) - k gaps between sorted points.

    After sorting, each adjacent pair of points has a gap. Dropping the k-1
    largest gaps splits the points into k runs. The value returned is the sum
    of the remaining len(points) - k gaps, clamped so that no gaps are kept
    when k >= len(points).

    Args:
        points: Iterable of integers (any order).
        k: Number of groups to split the points into.

    Returns:
        The sum of the kept gaps; 0 for empty or single-element input.
    """
    ordered = sorted(points)
    gaps = sorted(right - left for left, right in zip(ordered, ordered[1:]))
    # Clamp at 0: with k > len(points) a negative slice bound would drop
    # elements from the *end* of `gaps` rather than keeping none of them.
    keep = max(0, len(ordered) - k)
    return sum(gaps[:keep])


def main():
    """Read "n k" and then n integers from stdin; print min_total_span."""
    # n is implied by the length of the second line, so it is not used.
    _, k = map(int, input().split())
    points = list(map(int, input().split()))
    print(min_total_span(points, k))


if __name__ == "__main__":
    main()