import sys ln=list(map(int,input().split())) N=list(map(int,input().split())) rem=[] if ln[1]==1: print(max(N)-min(N)) sys.exit() elif ln[0]==ln[1]: print(0) sys.exit() N.sort() for i in range(0,ln[0]-1): rem.append(N[i+1]-N[i]) ans=0 for i in range(0,ln[0]-ln[1]): n=min(rem) ans+=n rem.remove(n) print(ans)