結果

問題 No.1345 Beautiful BINGO
ユーザー lam6er
提出日時 2025-04-15 22:25:03
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,639 bytes
コンパイル時間 365 ms
コンパイル使用メモリ 81,944 KB
実行使用メモリ 77,172 KB
最終ジャッジ日時 2025-04-15 22:27:16
合計ジャッジ時間 5,036 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 16 WA * 45
権限があれば一括ダウンロードができます

ソースコード

diff #

import bisect

def main():
    import sys
    input = sys.stdin.read().split()
    idx = 0
    N = int(input[idx])
    idx += 1
    M = int(input[idx])
    idx += 1
    matrix = []
    for _ in range(N):
        row = list(map(int, input[idx:idx+N]))
        idx += N
        matrix.append(row)
    
    # Precompute row costs and sort
    rows = [sum(row) for row in matrix]
    sorted_rows = sorted(enumerate(rows), key=lambda x: (x[1], x[0]))
    
    # Precompute column costs and sort
    cols = []
    for j in range(N):
        col_sum = sum(matrix[i][j] for i in range(N))
        cols.append(col_sum)
    sorted_cols = sorted(enumerate(cols), key=lambda x: (x[1], x[0]))
    
    # Precompute rows_selected and cols_selected
    rows_selected = [[] for _ in range(N+1)]
    for r in range(1, N+1):
        rows_selected[r] = rows_selected[r-1] + [sorted_rows[r-1][0]]
    
    cols_selected = [[] for _ in range(N+1)]
    for c in range(1, N+1):
        cols_selected[c] = cols_selected[c-1] + [sorted_cols[c-1][0]]
    
    # Precompute sum_intersect
    sum_intersect = [[0]*(N+1) for _ in range(N+1)]
    for r in range(N+1):
        for c in range(N+1):
            total = 0
            for i in rows_selected[r]:
                for j in cols_selected[c]:
                    total += matrix[i][j]
            sum_intersect[r][c] = total
    
    # Precompute sum_rows and sum_cols
    sum_rows = [0]*(N+1)
    current_sum = 0
    for r in range(1, N+1):
        current_sum += sorted_rows[r-1][1]
        sum_rows[r] = current_sum
    
    sum_cols = [0]*(N+1)
    current_sum = 0
    for c in range(1, N+1):
        current_sum += sorted_cols[c-1][1]
        sum_cols[c] = current_sum
    
    min_total = float('inf')
    
    for r in range(N+1):
        for c in range(N+1):
            current_k = r + c
            if current_k > 2*N:
                continue
            cost_rc = sum_rows[r] + sum_cols[c] - sum_intersect[r][c]
            d_needed = max(M - current_k, 0)
            if d_needed > 2:
                continue
            
            # Compute diag1 and diag2 costs
            cost_d1 = 0
            for i in range(N):
                row_list = rows_selected[r]
                pos = bisect.bisect_left(row_list, i)
                row_in = pos < len(row_list) and row_list[pos] == i
                col_list = cols_selected[c]
                pos = bisect.bisect_left(col_list, i)
                col_in = pos < len(col_list) and col_list[pos] == i
                if not row_in and not col_in:
                    cost_d1 += matrix[i][i]
            
            cost_d2 = 0
            for i in range(N):
                j = N - 1 - i
                row_list = rows_selected[r]
                pos = bisect.bisect_left(row_list, i)
                row_in = pos < len(row_list) and row_list[pos] == i
                col_list = cols_selected[c]
                pos = bisect.bisect_left(col_list, j)
                col_in = pos < len(col_list) and col_list[pos] == j
                if not row_in and not col_in:
                    cost_d2 += matrix[i][j]
            
            diag_costs = []
            if cost_d1 > 0:
                diag_costs.append(cost_d1)
            if cost_d2 > 0:
                diag_costs.append(cost_d2)
            diag_costs.sort()
            
            actual_d = min(d_needed, len(diag_costs))
            total_cost = cost_rc + sum(diag_costs[:actual_d])
            if current_k + actual_d >= M:
                if total_cost < min_total:
                    min_total = total_cost
    
    print(min_total)

if __name__ == "__main__":
    main()
0