import sys


def solve(rows, m):
    """Return the alternating sum of the rows' weighted scores.

    Each of the first ``m`` columns is weighted by its total over all rows.
    A row's score is the sum of its first ``m`` entries multiplied by the
    matching column totals.  Scores are ordered from largest to smallest and
    combined with alternating signs: the largest is added, the next is
    subtracted, and so on.

    Args:
        rows: Sequence of integer sequences, each with at least ``m`` items.
        m: Number of leading columns taken into account.

    Returns:
        The alternating sum as an int (0 when ``rows`` is empty).

    Raises:
        IndexError: If a row has fewer than ``m`` entries.
    """
    column_totals = [0] * m
    for row in rows:
        for j in range(m):
            column_totals[j] += row[j]

    # Score every row exactly once.  Only the score values matter for the
    # result, so the scores themselves are sorted rather than the rows;
    # ties contribute equal amounts, so their relative order is irrelevant.
    scores = sorted(
        (sum(row[j] * column_totals[j] for j in range(m)) for row in rows),
        reverse=True,
    )

    # Even positions (0, 2, ...) are added, odd positions are subtracted.
    return sum(scores[0::2]) - sum(scores[1::2])


def main():
    """Read ``N M`` and then ``N`` rows of integers from stdin; print the result."""
    readline = sys.stdin.readline  # local alias; avoids shadowing builtin input()
    n, m = map(int, readline().split())
    rows = [list(map(int, readline().split())) for _ in range(n)]
    print(solve(rows, m))


if __name__ == "__main__":
    main()