import sys input = sys.stdin.readline N,M=map(int,input().split()) A=[list(map(int,input().split())) for i in range(N)] V=[0]*M for i in range(N): for j in range(M): V[j]+=A[i][j] T=[0]*N for i in range(N): X=0 for j in range(M): X+=A[i][j]*V[j] T[i]=X T.sort(reverse=True) ANS=0 for i in range(len(T)): if i%2==0: ANS+=T[i] ANS*=2 for i in range(M): ANS-=V[i]**2 print(ANS)