def _total_spread_after_cuts(values, groups):
    """Return the summed (max - min) of runs formed by cutting sorted *values*.

    The values are sorted and cut at the ``groups - 1`` widest gaps between
    neighbouring elements. Each cut removes exactly its gap from the overall
    range, so the sum of the per-run spreads equals
    ``(largest - smallest) - sum(gaps that were cut)``.

    Args:
        values: Iterable of integers to partition.
        groups: Number of runs to form. If it exceeds the number of values,
            every gap is cut and the result is 0.

    Returns:
        The total spread as an int; 0 when fewer than two values are given.
    """
    ordered = sorted(values)
    if len(ordered) < 2:
        return 0

    # Widest gaps first, so the leading ``groups - 1`` entries are the cuts.
    gaps = sorted(
        (right - left for left, right in zip(ordered, ordered[1:])),
        reverse=True,
    )
    cuts = max(groups - 1, 0)
    return (ordered[-1] - ordered[0]) - sum(gaps[:cuts])


def main():
    """Read ``N K`` and then N integers from stdin and print the total spread.

    When ``N == K`` the answer is printed as 0 without reading the values
    line, because every value can sit alone in its own group.
    """
    n, k = map(int, input().split())
    if n == k:
        print(0)
        # Plain return instead of the site-provided exit(): same script
        # behaviour, but no SystemExit and no dependency on the site module.
        return

    values = list(map(int, input().split()))
    print(_total_spread_after_cuts(values, k))


if __name__ == "__main__":
    main()