import sys


def _squared_norm(row, m):
    """Return the sum of squares of the first *m* entries of *row*."""
    return sum(row[j] ** 2 for j in range(m))


def _column_sums(rows, m):
    """Return the component-wise sum over *rows* of their first *m* entries."""
    totals = [0] * m
    for row in rows:
        for j in range(m):
            totals[j] += row[j]
    return totals


def solve(vectors, m):
    """Split vectors into two groups by alternation and score the split.

    The vectors are ordered by squared Euclidean norm, largest first, with
    ties keeping their input order. They are then dealt alternately:
    positions 0, 2, 4, ... go to group X and positions 1, 3, 5, ... go to
    group Y.

    NOTE(review): this alternating deal is a greedy rule; whether it is the
    optimal split depends on the problem statement, which is not visible here.

    Args:
        vectors: sequence of integer rows. Only the first *m* entries of each
            row are used, and each row must have at least *m* entries. The
            sequence itself is not modified.
        m: number of components per vector.

    Returns:
        ``||sum(X)||**2 - ||sum(Y)||**2`` as an int (0 when *vectors* is empty).
    """
    # reverse=True preserves sort stability, so this matches a stable
    # ascending sort on the negated norm: equal-norm rows stay in input order.
    ordered = sorted(
        vectors, key=lambda row: _squared_norm(row, m), reverse=True
    )
    x = _column_sums(ordered[0::2], m)
    y = _column_sums(ordered[1::2], m)
    return sum(v * v for v in x) - sum(v * v for v in y)


def main():
    """Read N, M and then N lines of integers from stdin; print the score."""
    readline = sys.stdin.readline
    n, m = map(int, readline().split())
    vectors = [list(map(int, readline().split())) for _ in range(n)]
    print(solve(vectors, m))


if __name__ == "__main__":
    main()