import sys def main(): input = sys.stdin.read().split() idx = 0 N, K = int(input[idx]), int(input[idx+1]) idx +=2 A = list(map(int, input[idx:idx+N])) idx +=N B = list(map(int, input[idx:idx+N])) idx +=N P = [] for i in range(N): row = list(map(int, input[idx:idx+N])) idx +=N P.append(row) u = [0.0]*N v = [0.0]*N EPS = 1e-8 iterations = 100 for _ in range(iterations): # Update u for i in range(N): target = 2.0 * A[i] D = [2.0 * P[i][j] + v[j] for j in range(N)] left = -1e20 right = 1e20 for __ in range(100): mid = (left + right)/2 total = 0.0 for dj in D: val = mid + dj if val > 0: total += val if total < target - EPS: left = mid else: right = mid u[i] = (left + right)/2 # Update v for j in range(N): target = 2.0 * B[j] C = [2.0 * P[i][j] + u[i] for i in range(N)] left = -1e20 right = 1e20 for __ in range(100): mid = (left + right)/2 total = 0.0 for ci in C: val = mid + ci if val >0: total += val if total < target - EPS: left = mid else: right = mid v[j] = (left + right)/2 Q = [[0]*N for _ in range(N)] for i in range(N): for j in range(N): Q[i][j] = max(0.0, (u[i] + v[j] + 2.0 * P[i][j]) / 2.0 ) res = 0.0 for i in range(N): for j in range(N): diff = Q[i][j] - P[i][j] res += diff * diff print(int(round(res))) if __name__ == '__main__': main()