"""Greedy two-player row-picking game.

Reads ``N M`` followed by ``N`` rows of integers from stdin.  Two players
alternate taking one unused row each (player X moves first) until every row
has been taken.  Each player adds the first ``M`` values of the row
element-wise into their own length-``M`` total vector.  The printed score is
``sum(x_j**2) - sum(y_j**2)``.

On each turn the mover greedily takes the row that optimises the score right
after that move: X maximises it, Y minimises it.  Only the mover's totals
change, so both rules reduce to "take the row that maximises the squared norm
of my own totals".  Ties go to the lowest row index.
"""
import sys


def _sq_norm_after(totals, row):
    """Return the squared norm of ``totals`` after adding ``row`` element-wise.

    ``zip`` stops at ``len(totals)`` (== m), so only the first m values of
    ``row`` contribute.
    """
    return sum((t + v) ** 2 for t, v in zip(totals, row))


def play(rows, m):
    """Play the greedy game on ``rows`` and return the final score.

    Args:
        rows: sequence of integer rows; only the first ``m`` values of each
            row are used.
        m: number of columns that contribute to the score.

    Returns:
        ``sum(x_j**2) - sum(y_j**2)`` once every row has been taken, where
        ``x`` / ``y`` are the column totals of the first / second player.
        Returns ``0`` when ``rows`` is empty.
    """
    # totals[0] belongs to X (the maximiser), totals[1] to Y (the minimiser).
    totals = ([0] * m, [0] * m)
    taken = [False] * len(rows)

    # Iterative, one row per turn: a deep recursion could overflow the C stack
    # for large N even with a raised recursion limit.
    for turn in range(len(rows)):
        mine = totals[turn % 2]
        best, best_val = -1, None
        for i, row in enumerate(rows):
            if taken[i]:
                continue
            val = _sq_norm_after(mine, row)
            # Seed from the first free row rather than a 0 / 10**19 sentinel,
            # so the mover always takes a row.  The strict '>' keeps the
            # lowest index on ties.
            if best == -1 or val > best_val:
                best, best_val = i, val

        taken[best] = True
        chosen = rows[best]
        for j in range(m):
            mine[j] += chosen[j]

    x_tot, y_tot = totals
    return sum(v * v for v in x_tot) - sum(v * v for v in y_tot)


def main():
    """Read the game from stdin and print the final score."""
    readline = sys.stdin.readline
    n, m = map(int, readline().split())
    rows = [list(map(int, readline().split())) for _ in range(n)]
    print(play(rows, m))


if __name__ == "__main__":
    main()