import java.util.Scanner;

/**
 * Reads {@code B}, {@code N} and {@code N} values {@code C[0..N-1]} from standard input and
 * prints the smallest value of {@code sum(|x - C[i]|)} over all integers {@code x} in
 * {@code [0, (B + sum(C)) / N]}.
 */
public class Main {

    /**
     * Program entry point: parses the input, runs the search and prints the result.
     *
     * <p>When {@code N == 1} the program prints {@code 0} without searching.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        long total;
        long[] values;

        // try-with-resources releases the scanner even if parsing throws.
        try (Scanner scan = new Scanner(System.in)) {
            total = scan.nextLong(); // B is folded into the running total right away
            int count = scan.nextInt();
            values = new long[count];
            for (int i = 0; i < count; i++) {
                values[i] = scan.nextLong();
                total += values[i];
            }
        }

        if (values.length == 1) {
            System.out.println("0");
            return; // plain return instead of System.exit so the hosting JVM keeps running
        }

        long upper = total / values.length;
        System.out.println(minTotalDistance(values, upper));
    }

    /**
     * Returns the minimum of {@link #totalDistance} over integer targets in {@code [0, upper]},
     * or {@code Long.MAX_VALUE} when that range is empty.
     *
     * <p>Uses an integer ternary search; this relies on the objective being convex, which holds
     * for a sum of absolute differences.
     */
    private static long minTotalDistance(long[] values, long upper) {
        long lo = 0;
        long hi = upper;

        while (hi - lo > 3) {
            // Width of one third of the current interval. Probes are derived from the width
            // rather than from (2 * lo + hi) / 3 and (lo + 2 * hi) / 3, whose numerators can
            // overflow long for large bounds.
            long third = (hi - lo) / 3;
            long left = lo + third;
            long right = hi - third;

            if (totalDistance(values, left) > totalDistance(values, right)) {
                lo = left; // a minimiser lies at or to the right of the left probe
            } else {
                hi = right; // a minimiser lies at or to the left of the right probe
            }
        }

        // At most four candidates remain; evaluate each of them directly.
        long best = Long.MAX_VALUE;
        for (long x = lo; x <= hi; x++) {
            best = Math.min(best, totalDistance(values, x));
        }
        return best;
    }

    /** Returns the sum of {@code |target - v|} over every {@code v} in {@code values}. */
    private static long totalDistance(long[] values, long target) {
        long sum = 0;
        for (long v : values) {
            sum += Math.abs(target - v);
        }
        return sum;
    }
}