import sys


def min_total_range(values, k):
    """Return the sum of the smallest ``len(values) - k`` adjacent gaps.

    The values are sorted, and the differences between neighbouring sorted
    values ("gaps") are computed. The function returns the sum of the
    ``len(values) - k`` smallest gaps. Equivalently, it returns the total
    span ``max - min`` minus the ``k - 1`` largest gaps. That equals the
    summed ranges of the ``k`` contiguous runs left after cutting the sorted
    values at those largest gaps.

    Args:
        values: Iterable of integers.
        k: Number of runs. When ``k >= len(values)``, no gaps are kept and
            the result is 0.

    Returns:
        The integer sum described above; 0 for empty or single-element input.
    """
    ordered = sorted(values)
    gaps = sorted(b - a for a, b in zip(ordered, ordered[1:]))
    # Clamp at 0: a negative slice bound would silently drop gaps from the
    # *end* instead of keeping none, giving a wrong non-zero answer when k > n.
    keep = max(0, len(ordered) - k)
    return sum(gaps[:keep])


def main():
    """Read ``n k`` and then ``n`` integers from stdin, and print the answer."""
    # Read all tokens, so the values may be spread over any number of lines.
    data = sys.stdin.read().split()
    n, k = int(data[0]), int(data[1])
    values = [int(tok) for tok in data[2:2 + n]]
    print(min_total_range(values, k))


if __name__ == "__main__":
    main()