"""Alternating row-picking game on an N x M integer matrix.

Two players alternately take whole rows, and the first player moves first.
With S_j the total of column j and T_j the first player's total in column j,
the printed value is sum_j (T_j**2 - (S_j - T_j)**2).
"""


def solve(n: int, m: int, a) -> int:
    """Return sum_j (T_j**2 - (S_j - T_j)**2) under greedy alternating play.

    Args:
        n: number of rows of ``a`` that take part (rows ``a[0]..a[n-1]``).
        m: number of columns (each row must have at least ``m`` entries).
        a: the matrix as a sequence of integer rows.

    Returns:
        The objective value as an exact integer.
        It is 0 when ``n`` or ``m`` is 0.
    """
    col_sums = [sum(a[i][j] for i in range(n)) for j in range(m)]
    # T**2 - (S - T)**2 == 2*S*T - S**2, so the objective is linear in the set
    # of rows taken.  Each row i contributes weight sum_j a[i][j] * S_j, and
    # both players therefore simply grab the heaviest remaining row.
    weights = [sum(a[i][j] * col_sums[j] for j in range(m)) for i in range(n)]
    # The first player receives positions 0, 2, 4, ... of the descending order.
    first_player_weight = sum(sorted(weights, reverse=True)[::2])
    return 2 * first_player_weight - sum(s * s for s in col_sums)


def main() -> None:
    """Read "N M" followed by N*M integers from stdin and print the answer."""
    import sys

    # Read every token at once: faster than per-line input() and tolerant of
    # arbitrary whitespace/line layout.
    data = sys.stdin.buffer.read().split()
    n, m = int(data[0]), int(data[1])
    values = list(map(int, data[2:2 + n * m]))
    a = [values[i * m:(i + 1) * m] for i in range(n)]
    print(solve(n, m, a))


if __name__ == "__main__":
    main()