"""Sum the adjacent gaps of a sorted integer list, skipping the largest K-1.

Input (stdin):
    line 1: two integers ``N K``
    line 2: whitespace-separated integers

Output (stdout):
    a single integer, the sum described above.

NOTE(review): this looks like the classic "cover sorted points with K
segments, minimising total length" computation -- confirm against the
original problem statement.
"""


def sum_of_smallest_gaps(values, groups):
    """Return the sum of adjacent gaps in sorted *values*, minus the largest ones.

    The values are sorted, the differences between neighbours are taken, and
    the ``groups - 1`` largest differences are left out of the sum.

    Args:
        values: Iterable of integers, in any order.
        groups: How many groups to form; ``groups - 1`` gaps are dropped.
            Values of 1 or less drop nothing.

    Returns:
        The sum of the remaining gaps; 0 when there are fewer than two
        values or when every gap is dropped.
    """
    ordered = sorted(values)
    gaps = sorted(hi - lo for lo, hi in zip(ordered, ordered[1:]))

    # Compute an explicit keep-count instead of slicing with a negative stop:
    # ``gaps[:-0]`` would be empty, which is why a bare ``gaps[:-(groups - 1)]``
    # needs a special case for ``groups == 1``.
    dropped = max(groups - 1, 0)
    kept = max(len(gaps) - dropped, 0)
    return sum(gaps[:kept])


def main():
    """Read the two input lines from stdin and print the result."""
    # The first token (N) is consumed to honour the input format, but the
    # computation uses the actual number of values on the second line.
    _n, k = map(int, input().split())
    values = map(int, input().split())
    print(sum_of_smallest_gaps(values, k))


if __name__ == "__main__":
    main()