結果
問題 |
No.1324 Approximate the Matrix
|
ユーザー |
![]() |
提出日時 | 2025-06-12 20:00:10 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 3,071 bytes |
コンパイル時間 | 200 ms |
コンパイル使用メモリ | 81,912 KB |
実行使用メモリ | 81,460 KB |
最終ジャッジ日時 | 2025-06-12 20:03:33 |
合計ジャッジ時間 | 5,738 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 7 WA * 35 |
ソースコード
import sys def ipf(P, A, B, max_iter=1000, eps=1e-6): N = len(P) Q = [row[:] for row in P] for _ in range(max_iter): row_sums = [sum(row) for row in Q] for i in range(N): if row_sums[i] == 0: continue scale = A[i] / row_sums[i] for j in range(N): Q[i][j] *= scale col_sums = [sum(Q[i][j] for i in range(N)) for j in range(N)] for j in range(N): if col_sums[j] == 0: continue scale = B[j] / col_sums[j] for i in range(N): Q[i][j] *= scale row_ok = all(abs(sum(Q[i]) - A[i]) < eps for i in range(N)) col_ok = all(abs(sum(Q[i][j] for i in range(N)) - B[j]) < eps for j in range(N)) if row_ok and col_ok: break return Q def adjust_to_integers(Q, A, B): N = len(Q) for i in range(N): for j in range(N): Q[i][j] = round(Q[i][j]) row_sums = [sum(row) for row in Q] col_sums = [sum(Q[i][j] for i in range(N)) for j in range(N)] for i in range(N): if row_sums[i] != A[i]: diff = A[i] - row_sums[i] j = 0 while diff != 0 and j < N: current = Q[i][j] if diff > 0: add = min(diff, B[j] - col_sums[j]) if add > 0: Q[i][j] += add diff -= add col_sums[j] += add else: remove = min(-diff, current) Q[i][j] -= remove diff += remove col_sums[j] -= remove j += 1 for j in range(N): if col_sums[j] != B[j]: diff = B[j] - col_sums[j] i = 0 while diff != 0 and i < N: current = Q[i][j] if diff > 0: add = min(diff, A[i] - sum(Q[i])) if add > 0: Q[i][j] += add diff -= add row_sums[i] += add else: remove = min(-diff, current) Q[i][j] -= remove diff += remove row_sums[i] -= remove i += 1 return Q def main(): input = sys.stdin.read().split() ptr = 0 N = int(input[ptr]) ptr +=1 K = int(input[ptr]) ptr +=1 A = list(map(int, input[ptr:ptr+N])) ptr +=N B = list(map(int, input[ptr:ptr+N])) ptr +=N P = [] for _ in range(N): row = list(map(int, input[ptr:ptr+N])) ptr +=N P.append(row) Q_real = ipf(P, A, B) Q = [ [0]*N for _ in range(N) ] for i in range(N): for j in range(N): Q[i][j] = round(Q_real[i][j]) Q = adjust_to_integers(Q, A, B) total = 0 for i in range(N): for j in range(N): total += (P[i][j] - Q[i][j])**2 print(total) if __name__ == "__main__": main()