結果

問題 No.1345 Beautiful BINGO
ユーザー gew1fw
提出日時 2025-06-12 13:24:41
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,986 bytes
コンパイル時間 4,351 ms
コンパイル使用メモリ 82,068 KB
実行使用メモリ 91,728 KB
最終ジャッジ日時 2025-06-12 13:30:39
合計ジャッジ時間 20,195 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 46 WA * 15
権限があれば一括ダウンロードができます

ソースコード

diff #

def main():
    import sys
    input = sys.stdin.read().split()
    idx = 0
    N = int(input[idx])
    idx += 1
    M = int(input[idx])
    idx += 1
    A = []
    for _ in range(N):
        row = list(map(int, input[idx:idx+N]))
        idx += N
        A.append(row)
    
    # Precompute sum_rows and sum_rowcol
    sum_rows = [0] * (1 << N)
    sum_rowcol = [[0] * N for _ in range(1 << N)]
    for mask in range(1 << N):
        s = 0
        for i in range(N):
            if mask & (1 << i):
                for j in range(N):
                    sum_rowcol[mask][j] += A[i][j]
                s += sum(A[i])
        sum_rows[mask] = s
    
    # Precompute sum_col
    sum_col = [0] * N
    for j in range(N):
        for i in range(N):
            sum_col[j] += A[i][j]
    
    # Precompute sum_diag1 and sum_diag2
    sum_diag1 = 0
    sum_diag2 = 0
    for i in range(N):
        sum_diag1 += A[i][i]
        sum_diag2 += A[i][N-1 - i]
    
    current_min = float('inf')
    
    for row_mask in range(1 << N):
        r = bin(row_mask).count('1')
        sum_r = sum_rows[row_mask]
        val = []
        for j in range(N):
            # val_j is sum_col[j] - sum_rowcol[row_mask][j]
            val_j = sum_col[j] - sum_rowcol[row_mask][j]
            val.append((val_j, j))
        # Sort by val_j ascending
        val.sort()
        # Compute prefix sums
        prefix = [0] * (N + 1)
        for c in range(1, N + 1):
            prefix[c] = prefix[c-1] + val[c-1][0]
        # Iterate over c from 0 to N
        for c in range(0, N + 1):
            sum_c = prefix[c]
            sum_union = sum_r + sum_c
            line_count = r + c
            # Compute selected_columns (first c columns in sorted list)
            selected_columns = set()
            for i in range(c):
                selected_columns.add(val[i][1])
            # Compute sum_overlap_diag1 and sum_overlap_diag2
            sum_od1 = 0
            for i in range(N):
                # Check if row i is selected or column i is selected
                if (row_mask & (1 << i)) or (i in selected_columns):
                    sum_od1 += A[i][i]
            sum_od2 = 0
            for i in range(N):
                j_col = N - 1 - i
                # Check if row i is selected or column j_col is selected
                if (row_mask & (1 << i)) or (j_col in selected_columns):
                    sum_od2 += A[i][j_col]
            # Check all combinations of D1 and D2
            for D1 in [0, 1]:
                for D2 in [0, 1]:
                    new_line_count = line_count + D1 + D2
                    if new_line_count >= M:
                        add_d1 = (sum_diag1 - sum_od1) * D1
                        add_d2 = (sum_diag2 - sum_od2) * D2
                        total_sum = sum_union + add_d1 + add_d2
                        if total_sum < current_min:
                            current_min = total_sum
    print(current_min)

if __name__ == '__main__':
    main()
0