# coding: utf-8
"""Split sorted values into k contiguous groups minimising the summed spans.

Input on stdin: ``n k`` followed by ``n`` integers.
Output: the sum, over the k groups, of (largest - smallest) in each group.
"""
import sys


def min_total_span(k, positions):
    """Return the minimum total span when splitting *positions* into *k* groups.

    Cutting the sorted values at an adjacent gap removes exactly that gap from
    the overall span (max - min).  Cutting at the ``k - 1`` largest gaps
    therefore minimises the sum of per-group spans, so the result is
    ``(max - min) - sum(k - 1 largest gaps)``.

    Args:
        k: Number of groups; must be at least 1.
        positions: Iterable of integers (any order; duplicates allowed).

    Returns:
        The minimal summed span.  Returns 0 when there are no more values
        than groups (every value can sit in its own group), including empty
        input.

    Raises:
        ValueError: If ``k < 1``.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    values = sorted(positions)
    # Original code only special-cased k == n and crashed with IndexError
    # for k > n; any k >= n yields singleton groups with zero span.
    if k >= len(values):
        return 0
    gaps = sorted((b - a for a, b in zip(values, values[1:])), reverse=True)
    return values[-1] - values[0] - sum(gaps[:k - 1])


def main():
    """Read ``n k`` and ``n`` integers from stdin and print the answer."""
    tokens = sys.stdin.read().split()
    n, k = int(tokens[0]), int(tokens[1])
    positions = [int(tok) for tok in tokens[2:2 + n]]
    print(min_total_span(k, positions))


if __name__ == "__main__":
    main()