結果

問題 No.2505 matriX cOnstRuction
ユーザー suisen
提出日時 2023-07-27 16:46:57
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,591 ms / 2,500 ms
コード長 2,513 bytes
コンパイル時間 399 ms
コンパイル使用メモリ 82,176 KB
実行使用メモリ 380,552 KB
最終ジャッジ日時 2024-09-15 18:49:47
合計ジャッジ時間 30,280 ms
ジャッジサーバーID
(参考情報)
judge3 / judge6
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 64
権限があれば一括ダウンロードができます

ソースコード

diff #

from itertools import product
import sys
from time import perf_counter
from typing import List, Tuple

def floor_pow2(n: int):
    x = 1
    while (x << 1) <= n:
        x <<= 1
    return x

class Node:
    def __init__(self) -> None:
        self.lch = None
        self.rch = None
        self.weight = 0
    
    def left_child_or_create(self) -> 'Node':
        if self.lch is None:
            self.lch = Node()
        return self.lch

    def right_child_or_create(self) -> 'Node':
        if self.rch is None:
            self.rch = Node()
        return self.rch

L = 30

inf = 1 << 30

def solve(n: int, m: int, R: List[int], C: List[int], A: List[List[int]]):
    for i, j in product(range(n), range(m)):
        if A[0][0] ^ A[0][j] ^ A[i][0] ^ A[i][j]:
            print(-1)
            return
        
    root = Node()

    # f(X) += W * [X ^ Y > Z]
    def add_weight(Y: int, Z: int, W: int):
        YZ = Y ^ Z
        cur = root
        for bit in reversed(range(L)):
            if (YZ >> bit) & 1:
                nxt = cur.right_child_or_create()
            else:
                nxt = cur.left_child_or_create()
            if not ((Z >> bit) & 1):
                cur.weight += W
                nxt.weight -= W
            cur = nxt

    for i in range(n):
        Y = A[0][0] ^ A[i][0]
        if R[i]:
            add_weight(Y, 0, 1)
            add_weight(Y, R[i], 1)
            add_weight(Y, 2 * floor_pow2(R[i]) - 1, inf)
        else:
            add_weight(Y, 0, inf)

    for j in range(m):
        Y = A[0][j]
        if C[j]:
            add_weight(Y, 0, 1)
            add_weight(Y, C[j], 1)
            add_weight(Y, 2 * floor_pow2(C[j]) - 1, inf)
        else:
            add_weight(Y, 0, inf)

    min_weight = inf
    
    q : List[Node] = [root]
    for node in q:
        weight = node.weight

        if node.lch is None:
            min_weight = min(min_weight, weight)
        else:
            node.lch.weight += weight
            q.append(node.lch)

        if node.rch is None:
            min_weight = min(min_weight, weight)
        else:
            node.rch.weight += weight
            q.append(node.rch)

    if min_weight >= inf:
        print(-1)
        return

    print(min_weight)

input = sys.stdin.readline

T = int(input())
for _ in range(T):
    n, m = map(int, input().split())
    R = list(map(int, input().split()))
    C = list(map(int, input().split()))
    A = [list(map(int, input().split())) for _ in range(n)]

    solve(n, m, R, C, A)
0