"""Greedy alternating vector-picking game.

Input: a line ``N M`` followed by N lines of M integers each.  Two players
take turns (the first player moves on turns 0, 2, 4, ...).  On each turn the
current player takes the not-yet-taken row that maximizes the squared
Euclidean norm of their own running component-wise sum; ties go to the
lowest row index.  The program prints ``||X||^2 - ||Y||^2``, where X and Y
are the first and second player's final sums.
"""
import sys


def _pick_best(acc, vectors, taken):
    """Return the index of the untaken vector v maximizing ||acc + v||^2.

    Ties are broken in favor of the smallest index.  Returns -1 only if every
    vector is already taken.
    """
    best_idx = -1
    # Squared norms are >= 0, so starting below 0 guarantees the first
    # untaken vector is always accepted -- even when every candidate scores 0.
    # (Starting at 0 used to leave the index at 0, re-picking an already
    # taken row.)
    best_val = -1
    for j, vec in enumerate(vectors):
        if taken[j]:
            continue
        val = sum((s + x) ** 2 for s, x in zip(acc, vec))
        if val > best_val:  # strict '>' keeps the lowest index on ties
            best_idx, best_val = j, val
    return best_idx


def solve(n, m, a):
    """Play the game on the n rows of ``a`` (each of length m); return the score.

    The score is sum_k X[k]^2 - Y[k]^2, where X / Y are the component-wise
    sums of the rows taken by the first / second player.
    """
    sums = ([0] * m, [0] * m)  # sums[0]: first player, sums[1]: second player
    taken = [False] * n
    for turn in range(n):
        acc = sums[turn % 2]
        # At least one row is still untaken here (n - turn >= 1), so j >= 0.
        j = _pick_best(acc, a, taken)
        taken[j] = True
        row = a[j]
        for k in range(m):
            acc[k] += row[k]
    return sum(x * x - y * y for x, y in zip(*sums))


def main():
    """Read the game description from stdin and print the resulting score."""
    readline = sys.stdin.readline
    n, m = map(int, readline().split())
    a = [list(map(int, readline().split())) for _ in range(n)]
    print(solve(n, m, a))


if __name__ == "__main__":
    main()