def min_total_span(n, k, a):
    """Return the sum of the ``n - k`` smallest adjacent gaps of sorted *a*.

    The values in *a* are sorted, the differences between neighbouring
    values are computed, and the ``n - k`` smallest differences are summed.
    Presumably this is the total length needed when the points are split
    into ``k`` contiguous groups (cutting the largest gaps) — confirm
    against the problem statement.

    Args:
        n: First integer from the input header line.
        k: Second integer from the input header line.
        a: The values read from the second input line (any order).

    Returns:
        The sum of the ``max(0, n - k)`` smallest adjacent gaps; 0 when
        there are no gaps to take.
    """
    xs = sorted(a)
    gaps = sorted(hi - lo for lo, hi in zip(xs, xs[1:]))
    # Clamp at 0: a negative slice bound (k > n) would otherwise keep a
    # prefix of the gaps (e.g. gaps[:-1]) instead of selecting nothing.
    keep = max(0, n - k)
    return sum(gaps[:keep])


def main():
    """Read ``n k`` and the values from stdin, then print the answer."""
    n, k = map(int, input().split())
    *a, = map(int, input().split())
    print(min_total_span(n, k, a))


if __name__ == "__main__":
    main()