import sys


def min_total_span(values, k):
    """Return the minimum total span when *values* are split into *k* groups.

    Sorted values are cut into at most ``k`` contiguous groups. The cost of a
    group is ``max - min``, and the answer is the smallest possible sum of
    group costs. Cutting at a gap removes that gap from the total. So the
    answer is the sum of the ``len(values) - k`` smallest adjacent gaps,
    which equals the full range minus the ``k - 1`` largest gaps.

    Args:
        values: Integers to group (any order; duplicates allowed).
        k: Number of groups; expected to be >= 1.

    Returns:
        The minimal total span. It is 0 when ``k >= len(values)``, because
        every value can sit in its own group.
    """
    points = sorted(values)
    gaps = sorted(b - a for a, b in zip(points, points[1:]))
    keep = len(points) - k
    # Guard: a non-positive slice bound would take elements from the end.
    if keep <= 0:
        return 0
    return sum(gaps[:keep])


if __name__ == "__main__":
    # Input: first line "n k", second line the n values.
    header = list(map(int, sys.stdin.readline().split()))
    numbers = list(map(int, sys.stdin.readline().split()))
    print(min_total_span(numbers, header[1]))