"""Print the sum of the smallest adjacent gaps of a sorted list of integers.

Input (stdin):
    line 1: two integers ``n k`` (``n`` is the declared number of values)
    line 2: the integer values, whitespace separated

Output (stdout):
    the sum of all gaps between neighbouring sorted values, except the
    ``k - 1`` largest gaps.
"""


def sum_smallest_gaps(values: list, k: int) -> int:
    """Return the sum of adjacent sorted gaps, dropping the ``k - 1`` largest.

    The values are sorted, the differences between neighbours are computed,
    and the ``k - 1`` largest differences are discarded; the rest are summed.
    This equals the minimum total (max - min) over groups when the sorted
    values are cut into ``k`` contiguous groups.

    Args:
        values: integers in any order; the argument is not modified.
        k: number of groups. Values below 1 are treated as 1 (nothing is
            dropped); values larger than the number of gaps drop every gap.

    Returns:
        The sum of the retained gaps (0 when there are fewer than two values
        or when every gap is dropped).
    """
    ordered = sorted(values)
    gaps = sorted(right - left for left, right in zip(ordered, ordered[1:]))

    # Compute an explicit, non-negative count of gaps to keep. A negative
    # slice bound such as gaps[:-(k - 1)] collapses to gaps[:0] when k == 1,
    # which would wrongly drop everything and force a special case.
    dropped = max(k - 1, 0)
    kept = max(len(gaps) - dropped, 0)
    return sum(gaps[:kept])


def main() -> None:
    """Read the problem input from stdin and print the answer."""
    # The declared count is not needed: gaps are derived from the values
    # actually present on the second line.
    _, k = map(int, input().split())
    values = list(map(int, input().split()))
    print(sum_smallest_gaps(values, k))


if __name__ == "__main__":
    main()