"""Greedy alternating split of integer row vectors.

Reads ``n m`` followed by ``n`` rows of ``m`` integers from stdin. The rows
are ordered by their dot product with the column-total vector (largest
first) and alternately assigned to two sides ``X`` and ``Y`` (``X`` first).
The script prints ``|X|^2 - |Y|^2``.

NOTE(review): this presumably models optimal play of a game in which two
players alternately take rows; confirm against the problem statement.
"""
import sys


def solve(rows):
    """Return ``|X|^2 - |Y|^2`` for the greedy alternating split of *rows*.

    ``tot`` is the column-wise sum of all rows. Each row's key is its dot
    product with ``tot``. Rows are taken in descending key order: even
    positions are summed component-wise into ``X`` and odd positions into
    ``Y``.

    Every row lands on exactly one side, so ``X + Y == tot`` and therefore
    ``|X|^2 - |Y|^2 == (X - Y) . (X + Y) == (X - Y) . tot``. Each row thus
    contributes ``+key`` (to ``X``) or ``-key`` (to ``Y``), and the score is
    the alternating sum of the sorted keys. With integer input this is
    exact, and the order among equal keys does not affect the result.

    Args:
        rows: sequence of equal-length integer sequences; may be empty.

    Returns:
        The integer score (0 for no rows).
    """
    tot = [sum(col) for col in zip(*rows)]
    keys = sorted(
        (sum(a * t for a, t in zip(row, tot)) for row in rows),
        reverse=True,
    )
    # Even positions go to X (+key), odd positions go to Y (-key).
    return sum(keys[0::2]) - sum(keys[1::2])


def main():
    """Read ``n m`` and ``n`` rows of ``m`` integers from stdin; print the score."""
    data = sys.stdin.buffer.read().split()
    n, m = int(data[0]), int(data[1])
    values = list(map(int, data[2:2 + n * m]))
    rows = [values[i * m:(i + 1) * m] for i in range(n)]
    print(solve(rows))


if __name__ == "__main__":
    main()