import numpy as np


def solve(rows, m):
    """Return the score difference of a two-player greedy row-picking game.

    The players alternate, the first player starting. On each turn the
    current player takes the untaken row with the largest row sum; ties go
    to the lowest row index. Each player takes ``len(rows) // 2`` rows, so
    one row is left over when the row count is odd. The score is the sum of
    the squared column totals of the first player's rows, minus the same
    quantity for the second player's rows.

    NOTE(review): the greedy key is the plain row sum, as in the original
    script. Confirm that this is the intended strategy for the problem.

    Args:
        rows: ``n`` rows of integers, each of length ``m``.
        m: Number of columns. It is needed to shape the matrix when
            ``rows`` is empty.

    Returns:
        The score difference as an exact Python ``int``.
    """
    matrix = np.asarray(rows, dtype=np.int64).reshape(len(rows), m)
    row_sums = matrix.sum(axis=1)
    # A stable ascending sort on the negated sums gives the same order as
    # repeatedly taking np.argmax over the untaken rows (first index wins on
    # ties). It never re-selects a taken row, which the old "set sum to 0"
    # marking did when the remaining sums were negative.
    order = np.argsort(-row_sums, kind="stable")
    picks = order[: 2 * (len(rows) // 2)]
    mine = matrix[picks[0::2]].sum(axis=0)
    theirs = matrix[picks[1::2]].sum(axis=0)
    # Square in Python ints: float64 loses precision above 2**53, and int64
    # can overflow when squaring large totals.
    return sum(int(v) ** 2 for v in mine) - sum(int(v) ** 2 for v in theirs)


def main():
    """Read ``N M`` and an ``N x M`` integer matrix from stdin, then print the score."""
    n, m = map(int, input().split())
    rows = [list(map(int, input().split())) for _ in range(n)]
    print(solve(rows, m))


if __name__ == "__main__":
    main()