結果

問題 No.1324 Approximate the Matrix
ユーザー gew1fw
提出日時 2025-06-12 20:00:10
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,071 bytes
コンパイル時間 200 ms
コンパイル使用メモリ 81,912 KB
実行使用メモリ 81,460 KB
最終ジャッジ日時 2025-06-12 20:03:33
合計ジャッジ時間 5,738 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 7 WA * 35
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def ipf(P, A, B, max_iter=1000, eps=1e-6):
    N = len(P)
    Q = [row[:] for row in P]

    for _ in range(max_iter):
        row_sums = [sum(row) for row in Q]
        for i in range(N):
            if row_sums[i] == 0:
                continue
            scale = A[i] / row_sums[i]
            for j in range(N):
                Q[i][j] *= scale

        col_sums = [sum(Q[i][j] for i in range(N)) for j in range(N)]
        for j in range(N):
            if col_sums[j] == 0:
                continue
            scale = B[j] / col_sums[j]
            for i in range(N):
                Q[i][j] *= scale

        row_ok = all(abs(sum(Q[i]) - A[i]) < eps for i in range(N))
        col_ok = all(abs(sum(Q[i][j] for i in range(N)) - B[j]) < eps for j in range(N))
        if row_ok and col_ok:
            break
    return Q

def adjust_to_integers(Q, A, B):
    N = len(Q)
    for i in range(N):
        for j in range(N):
            Q[i][j] = round(Q[i][j])

    row_sums = [sum(row) for row in Q]
    col_sums = [sum(Q[i][j] for i in range(N)) for j in range(N)]

    for i in range(N):
        if row_sums[i] != A[i]:
            diff = A[i] - row_sums[i]
            j = 0
            while diff != 0 and j < N:
                current = Q[i][j]
                if diff > 0:
                    add = min(diff, B[j] - col_sums[j])
                    if add > 0:
                        Q[i][j] += add
                        diff -= add
                        col_sums[j] += add
                else:
                    remove = min(-diff, current)
                    Q[i][j] -= remove
                    diff += remove
                    col_sums[j] -= remove
                j += 1

    for j in range(N):
        if col_sums[j] != B[j]:
            diff = B[j] - col_sums[j]
            i = 0
            while diff != 0 and i < N:
                current = Q[i][j]
                if diff > 0:
                    add = min(diff, A[i] - sum(Q[i]))
                    if add > 0:
                        Q[i][j] += add
                        diff -= add
                        row_sums[i] += add
                else:
                    remove = min(-diff, current)
                    Q[i][j] -= remove
                    diff += remove
                    row_sums[i] -= remove
                i += 1

    return Q

def main():
    input = sys.stdin.read().split()
    ptr = 0
    N = int(input[ptr])
    ptr +=1
    K = int(input[ptr])
    ptr +=1

    A = list(map(int, input[ptr:ptr+N]))
    ptr +=N

    B = list(map(int, input[ptr:ptr+N]))
    ptr +=N

    P = []
    for _ in range(N):
        row = list(map(int, input[ptr:ptr+N]))
        ptr +=N
        P.append(row)

    Q_real = ipf(P, A, B)
    Q = [ [0]*N for _ in range(N) ]
    for i in range(N):
        for j in range(N):
            Q[i][j] = round(Q_real[i][j])

    Q = adjust_to_integers(Q, A, B)

    total = 0
    for i in range(N):
        for j in range(N):
            total += (P[i][j] - Q[i][j])**2

    print(total)

if __name__ == "__main__":
    main()
0