n, m = map(int, input().split())
chests = [list(map(int, input().split())) for _ in range(n)]

# Calculate T[j], the total number of jewels of type j over all chests
T = [0] * m
for j in range(m):
    for i in range(n):
        T[j] += chests[i][j]

# Calculate the value of each chest: sum over j of chest[j] * T[j]
values = []
for chest in chests:
    value = sum(chest[j] * T[j] for j in range(m))
    values.append(value)

# Sort values in descending order
values.sort(reverse=True)

# Sum of the values taken by the first player (indices 0, 2, 4, ...)
sum_S = 0
for i in range(0, n, 2):
    sum_S += values[i]

# Calculate the final answer.
# If a[j] and b[j] = T[j] - a[j] are the type-j counts held by the first and
# second player, then sum_j (a[j]^2 - b[j]^2) = sum_j T[j] * (2*a[j] - T[j])
# = 2 * sum_S - sum_j T[j]^2, because sum_S = sum_j a[j] * T[j].
sum_T_sq = sum(t * t for t in T)
result = 2 * sum_S - sum_T_sq
print(result)
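
# Optional cross-check (hypothetical helpers, not part of the original solution).
# This is a sketch that ASSUMES the underlying game is: two players alternately
# take one chest each, the first player moving first, until none remain; the
# first player maximizes sum_j (a[j]^2 - b[j]^2) and the second minimizes it,
# where a[j] and b[j] are the type-j counts each player ends up holding.
# greedy_score repeats the closed-form computation of the script above;
# brute_force_score is an exhaustive minimax over bitmasks, so keep n small
# (about 10 or fewer chests). Both names are illustrative only.
from functools import lru_cache


def greedy_score(chests):
    """Closed-form answer: sort chest values, first player takes every other one."""
    m = len(chests[0]) if chests else 0
    T = [sum(chest[j] for chest in chests) for j in range(m)]
    values = sorted((sum(c * t for c, t in zip(chest, T)) for chest in chests),
                    reverse=True)
    sum_S = sum(values[0::2])
    return 2 * sum_S - sum(t * t for t in T)


def brute_force_score(chests):
    """Exact minimax value of the assumed game, found by exhaustive search."""
    n = len(chests)
    m = len(chests[0]) if chests else 0
    T = [sum(chest[j] for chest in chests) for j in range(m)]

    @lru_cache(maxsize=None)
    def best(taken_by_first, remaining):
        if remaining == 0:
            # Game over: total up the first player's jewels of each type.
            a = [0] * m
            for i in range(n):
                if (taken_by_first >> i) & 1:
                    for j in range(m):
                        a[j] += chests[i][j]
            return sum(a[j] ** 2 - (T[j] - a[j]) ** 2 for j in range(m))
        taken_so_far = n - bin(remaining).count("1")
        first_to_move = taken_so_far % 2 == 0
        outcomes = []
        for i in range(n):
            if (remaining >> i) & 1:
                # Chest i goes to whoever is moving now.
                nxt = (taken_by_first | (1 << i)) if first_to_move else taken_by_first
                outcomes.append(best(nxt, remaining & ~(1 << i)))
        return max(outcomes) if first_to_move else min(outcomes)

    return best(0, (1 << n) - 1)

# Usage: greedy_score([[3], [1], [2]]) and brute_force_score([[3], [1], [2]])
# both return 12 (the first player ends with 3 + 1 = 4 jewels: 16 - 4 = 12).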