import java.util.*; public class Main { public static void main(String[] args) { Scanner sc = new Scanner(System.in); long total = sc.nextInt(); int n = sc.nextInt(); int[] arr = new int[n]; for (int i = 0; i < n; i++) { arr[i] = sc.nextInt(); total += arr[i]; } long left = 0; long right = total / n; boolean isRight = false; while (right - left > 2) { long m1 = (left * 2 + right) / 3; long m2 = (left + right * 2) / 3; if (calc((int)m1, arr) <= calc((int)m2, arr)) { right = m2; isRight = false; } else { isRight = true; left = m1; } } System.out.println(Math.min(Math.min(calc((int)right, arr), calc((int)left, arr)), Math.min(calc((int)((left * 2 + right) / 3), arr), calc((int)((left + right * 2) / 3), arr)))); } static int calc(int x, int[] arr) { int ans = 0; for (int y : arr) { ans += Math.abs(x - y); } return ans; } }