import sys


def solve(rows, m):
    """Return the objective value for the integer matrix ``rows`` (``m`` columns).

    Let ``S_j`` be the sum of column ``j`` over all rows. Each row ``r`` gets
    the key ``2 * sum_j S_j * r[j]``. Rows are ranked by key in descending
    order, and the rows at ranks 0, 2, 4, ... are selected. With ``T_j`` the
    column sums of the selected rows, the result is::

        sum(keys of selected rows) - sum_j S_j**2
        == sum_j (T_j**2 - (S_j - T_j)**2)

    NOTE(review): this appears to model two players alternately drafting
    rows, first player first, each greedily taking the largest remaining key.
    The value is additive in the keys of the first player's rows. Confirm
    this against the problem statement.

    Args:
        rows: list of integer rows. Only the first ``m`` entries of each row
            are used.
        m: number of columns.

    Returns:
        The integer value described above; 0 for an empty matrix.
    """
    col_sums = [sum(row[j] for row in rows) for j in range(m)]
    # zip stops after m entries because col_sums has length m.
    keys = sorted(
        (2 * sum(s * x for s, x in zip(col_sums, row)) for row in rows),
        reverse=True,
    )
    # Even ranks (0, 2, 4, ...) are the selected rows.
    return sum(keys[::2]) - sum(s * s for s in col_sums)


def main():
    """Read ``n m`` and then ``n`` rows of integers from stdin; print the result."""
    readline = sys.stdin.readline
    n, m = map(int, readline().split())
    rows = [list(map(int, readline().split())) for _ in range(n)]
    print(solve(rows, m))


if __name__ == "__main__":
    main()