def min_total_deviation(b: int, m: int, values: list) -> int:
    """Return min over integers x in [min(values), (sum(values) + b) // m] of sum(|v - x|).

    The cost f(x) = sum(|v - x| for v in values) is convex and piecewise
    linear in x. It is minimised on the interval between the lower and upper
    medians of ``values``. Over a closed integer range, the optimum is
    therefore the lower median clamped into that range.

    This replaces the original fixed 100-step ternary search. That search
    stalled when the range width reached 2 and never evaluated the middle
    point.

    Args:
        b: Integer added to ``sum(values)`` before dividing by ``m`` to form
            the upper bound of the search range.
        m: Positive divisor for the upper bound.
        values: Non-empty list of integers.

    Returns:
        The minimum total absolute deviation over the allowed range.

    Raises:
        ValueError: if ``values`` is empty (from ``min``).
        ZeroDivisionError: if ``m`` is 0.
    """
    lo = min(values)
    hi = (sum(values) + b) // m
    ordered = sorted(values)
    # Lower median: f is non-increasing up to here, so it is optimal when in range.
    median = ordered[(len(ordered) - 1) // 2]
    # The median is always >= lo (= min(values)), so usually only the upper bound bites.
    # NOTE(review): the range is assumed non-empty (hi >= lo); if hi < lo this
    # falls back to lo. Confirm against the problem constraints.
    best_x = max(lo, min(median, hi))
    return sum(abs(v - best_x) for v in values)


def main() -> None:
    """Read ``B M c_1 c_2 ...`` from stdin and print the minimal total deviation."""
    import sys

    b, m, *values = map(int, sys.stdin.read().split())
    print(min_total_deviation(b, m, values))


if __name__ == "__main__":
    main()