結果
| 問題 |
No.1324 Approximate the Matrix
|
| コンテスト | |
| ユーザー |
gew1fw
|
| 提出日時 | 2025-06-12 14:55:45 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 3,071 bytes |
| コンパイル時間 | 184 ms |
| コンパイル使用メモリ | 82,884 KB |
| 実行使用メモリ | 81,804 KB |
| 最終ジャッジ日時 | 2025-06-12 14:58:02 |
| 合計ジャッジ時間 | 4,779 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 7 WA * 35 |
ソースコード
import sys
def ipf(P, A, B, max_iter=1000, eps=1e-6):
N = len(P)
Q = [row[:] for row in P]
for _ in range(max_iter):
row_sums = [sum(row) for row in Q]
for i in range(N):
if row_sums[i] == 0:
continue
scale = A[i] / row_sums[i]
for j in range(N):
Q[i][j] *= scale
col_sums = [sum(Q[i][j] for i in range(N)) for j in range(N)]
for j in range(N):
if col_sums[j] == 0:
continue
scale = B[j] / col_sums[j]
for i in range(N):
Q[i][j] *= scale
row_ok = all(abs(sum(Q[i]) - A[i]) < eps for i in range(N))
col_ok = all(abs(sum(Q[i][j] for i in range(N)) - B[j]) < eps for j in range(N))
if row_ok and col_ok:
break
return Q
def adjust_to_integers(Q, A, B):
N = len(Q)
for i in range(N):
for j in range(N):
Q[i][j] = round(Q[i][j])
row_sums = [sum(row) for row in Q]
col_sums = [sum(Q[i][j] for i in range(N)) for j in range(N)]
for i in range(N):
if row_sums[i] != A[i]:
diff = A[i] - row_sums[i]
j = 0
while diff != 0 and j < N:
current = Q[i][j]
if diff > 0:
add = min(diff, B[j] - col_sums[j])
if add > 0:
Q[i][j] += add
diff -= add
col_sums[j] += add
else:
remove = min(-diff, current)
Q[i][j] -= remove
diff += remove
col_sums[j] -= remove
j += 1
for j in range(N):
if col_sums[j] != B[j]:
diff = B[j] - col_sums[j]
i = 0
while diff != 0 and i < N:
current = Q[i][j]
if diff > 0:
add = min(diff, A[i] - sum(Q[i]))
if add > 0:
Q[i][j] += add
diff -= add
row_sums[i] += add
else:
remove = min(-diff, current)
Q[i][j] -= remove
diff += remove
row_sums[i] -= remove
i += 1
return Q
def main():
input = sys.stdin.read().split()
ptr = 0
N = int(input[ptr])
ptr +=1
K = int(input[ptr])
ptr +=1
A = list(map(int, input[ptr:ptr+N]))
ptr +=N
B = list(map(int, input[ptr:ptr+N]))
ptr +=N
P = []
for _ in range(N):
row = list(map(int, input[ptr:ptr+N]))
ptr +=N
P.append(row)
Q_real = ipf(P, A, B)
Q = [ [0]*N for _ in range(N) ]
for i in range(N):
for j in range(N):
Q[i][j] = round(Q_real[i][j])
Q = adjust_to_integers(Q, A, B)
total = 0
for i in range(N):
for j in range(N):
total += (P[i][j] - Q[i][j])**2
print(total)
if __name__ == "__main__":
main()
gew1fw