"""Read ``n`` integer rows from stdin and print a score difference.

Input format (as consumed by ``main``):
    line 1:  two integers ``n m`` (only ``n`` is used; ``m`` is read and ignored)
    next n:  one row of whitespace-separated integers per line

Output: a single integer, see ``score_difference``.
"""
from functools import reduce


def add(x, y):
    """Return the element-wise sum of ``x`` and ``y`` as a new list.

    ``zip`` stops at the shorter argument, so the result has
    ``min(len(x), len(y))`` entries.
    """
    return [v + w for v, w in zip(x, y)]


def score_difference(rows):
    """Return ``|picked|**2 - |total - picked|**2`` for the given rows.

    Steps:
      1. ``total`` is the element-wise sum of all rows.
      2. Each row gets a weight equal to its dot product with ``total``.
      3. Row indices are sorted by weight, largest first (ties keep input
         order because ``sorted`` is stable, including with ``reverse=True``).
      4. ``picked`` is the sum of the rows at positions 0, 2, 4, ... of that
         ordering.

    Args:
        rows: sequence of equal-length integer sequences.

    Returns:
        The integer difference between the squared norm of ``picked`` and the
        squared norm of the remainder ``total - picked``. Returns 0 when
        ``rows`` is empty (both sums are then empty).
    """
    if not rows:
        # reduce() without an initial value raises TypeError on an empty
        # sequence; with no rows both sides are the zero vector.
        return 0

    total = reduce(add, rows)
    weights = [sum(t * v for t, v in zip(total, row)) for row in rows]
    order = sorted(range(len(rows)), key=weights.__getitem__, reverse=True)

    # NOTE(review): taking every other row of the weight ordering presumably
    # models two sides alternately choosing rows -- confirm against the
    # problem statement this script was written for.
    picked = reduce(add, [rows[i] for i in order[::2]])

    picked_sq = sum(c ** 2 for c in picked)
    rest_sq = sum((t - c) ** 2 for t, c in zip(total, picked))
    return picked_sq - rest_sq


def main():
    """Parse the input described in the module docstring and print the result."""
    n, _m = map(int, input().split())  # the second header value is not needed
    rows = [list(map(int, input().split())) for _ in range(n)]
    print(score_difference(rows))


if __name__ == "__main__":
    main()