import sys

import numpy as np


def solve(A):
    """Return sum_j me_j**2 - sum_j enj_j**2 after greedily splitting the rows of A.

    Rows are ordered by descending key ``row . S``, where ``S`` is the
    column-sum vector of ``A``. Rows at even positions (0, 2, 4, ...) are
    summed into ``me``; rows at odd positions are summed into ``enj``.
    Presumably this models two players alternately picking rows, with the
    first player owning ``me``. That is inferred from the names; confirm
    against the problem statement.

    Why the greedy order is correct: every row goes to exactly one side, so
    ``me + enj == S`` column-wise. Therefore

        sum_j (me_j**2 - enj_j**2) = sum_j (me_j - enj_j) * S_j
                                   = sum(keys of me's rows) - sum(keys of enj's rows).

    Each row therefore contributes only through its scalar key. Taking the
    largest remaining key is optimal for both sides, and ties can be broken
    arbitrarily without changing the result.

    Args:
        A: 2-D array-like of integers (rows x columns). It is converted to
            int64. NOTE(review): assumes ``row . S`` and the squared column
            sums fit in int64; confirm against the input constraints.

    Returns:
        The score difference as a Python ``int``.
    """
    A = np.asarray(A, dtype=np.int64)
    S = A.sum(axis=0)
    # Descending by key. The unstable default sort is fine: ties don't
    # change the result (see docstring).
    order = np.argsort(A @ S)[::-1]
    A = A[order]
    me = A[::2].sum(axis=0)
    enj = A[1::2].sum(axis=0)
    return int(np.dot(me, me) - np.dot(enj, enj))


def main():
    """Read 'N M' followed by N*M integers from stdin and print solve()'s result."""
    read = sys.stdin.buffer.read
    readline = sys.stdin.buffer.readline
    N, M = map(int, readline().split())
    # int64 rather than int32: sums and products of int32 arrays stay 32-bit
    # on some platforms (e.g. Windows before NumPy 2) and would wrap silently.
    A = np.array(read().split(), np.int64).reshape(N, M)
    print(solve(A))


if __name__ == "__main__":
    main()