"""Read an n x m integer matrix and print a score from an alternating split of its rows.

Let ``r`` be the column-sum vector and ``p_i = r . a_i`` for each row ``a_i``.
Rows are ranked by ``p_i`` in descending order, and the rows at ranks
0, 2, 4, ... are summed into ``x``.  The printed value is
``|x|^2 - |r - x|^2``.

Because ``|x|^2 - |r - x|^2 == 2 * (r . x) - |r|^2`` and ``r . x`` equals the
sum of ``p_i`` over the rows summed into ``x``, the score depends only on
which ``p_i`` values land at even ranks.  That is why ranking by ``p_i`` is
enough, and why the tie-break order between equal ``p_i`` does not matter.

NOTE(review): this appears to solve a two-player game in which players
alternately take rows (the first maximising, the second minimising the
score).  Confirm against the problem statement.
"""

import sys
from typing import Sequence


def solve(m: int, rows: Sequence[Sequence[int]]) -> int:
    """Return the score for ``rows``, each expected to have ``m`` entries.

    Equivalent to summing the rows at even descending-``p_i`` ranks into
    ``x`` and returning ``|x|^2 - |r - x|^2``.  It is computed via the exact
    integer identity ``2 * sum(p at even ranks) - |r|^2``, so ``x`` is never
    materialised.  An empty ``rows`` yields 0.
    """
    col_sums = [0] * m
    for row in rows:
        for j, v in enumerate(row):
            col_sums[j] += v

    # p_i = r . a_i: each row's projection onto the total vector r.
    projections = [sum(rj * v for rj, v in zip(col_sums, row)) for row in rows]

    # Values at ranks 0, 2, 4, ... of the descending order form the "x" side.
    even_ranked = sorted(projections, reverse=True)[::2]
    return 2 * sum(even_ranked) - sum(rj * rj for rj in col_sums)


def main() -> None:
    """Parse ``n m`` followed by ``n * m`` integers from stdin and print the score."""
    # Token-based parsing: faster than per-line input() and tolerant of
    # rows that are wrapped across several lines.
    data = sys.stdin.buffer.read().split()
    n, m = int(data[0]), int(data[1])
    values = list(map(int, data[2:2 + n * m]))
    rows = [values[i * m:(i + 1) * m] for i in range(n)]
    print(solve(m, rows))


if __name__ == "__main__":
    main()