結果
問題 | No.2505 matriX cOnstRuction |
ユーザー |
|
提出日時 | 2023-07-27 16:46:57 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 1,591 ms / 2,500 ms |
コード長 | 2,513 bytes |
コンパイル時間 | 399 ms |
コンパイル使用メモリ | 82,176 KB |
実行使用メモリ | 380,552 KB |
最終ジャッジ日時 | 2024-09-15 18:49:47 |
合計ジャッジ時間 | 30,280 ms |
ジャッジサーバーID (参考情報) | judge3 / judge6 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 64 |
ソースコード
from itertools import product
import sys
from time import perf_counter
from typing import List, Tuple


def floor_pow2(n: int) -> int:
    """Return the largest power of two that is <= n (1 when n <= 1)."""
    return 1 << (n.bit_length() - 1) if n > 1 else 1


class Node:
    """Binary-trie node carrying an additive weight that is pushed down later."""

    # Many nodes are allocated; __slots__ drops the per-instance __dict__.
    __slots__ = ("lch", "rch", "weight")

    def __init__(self) -> None:
        self.lch = None  # child for a 0 bit
        self.rch = None  # child for a 1 bit
        self.weight = 0

    def left_child_or_create(self) -> 'Node':
        """Return the 0-bit child, creating it if absent."""
        if self.lch is None:
            self.lch = Node()
        return self.lch

    def right_child_or_create(self) -> 'Node':
        """Return the 1-bit child, creating it if absent."""
        if self.rch is None:
            self.rch = Node()
        return self.rch


L = 30  # minimum number of value bits the trie covers
inf = 1 << 30  # weight sentinel marking an impossible choice of X


def _bit_width(*groups) -> int:
    """Return max(L, bit length of every int in the given iterables)."""
    width = L
    for group in groups:
        for value in group:
            width = max(width, value.bit_length())
    return width


def solve(n: int, m: int, R: List[int], C: List[int], A: List[List[int]]) -> int:
    """Print (and return) the answer for one test case, or -1 if impossible.

    If some cell has A[0][0] ^ A[0][j] ^ A[i][0] ^ A[i][j] != 0, the answer
    is -1. Otherwise the answer is min over a free XOR offset X of f(X), where

      * row i adds    cost(X ^ A[0][0] ^ A[i][0], R[i]),
      * column j adds cost(X ^ A[0][j], C[j]),

    and cost(d, r) is 0 if d == 0, 1 if 0 < d <= r, 2 if r > 0 and r < d and
    d has no more bits than r, and impossible otherwise. Each cost is a sum of
    threshold terms W * [X ^ Y > Z], which a binary trie over X accumulates.
    """
    a00 = A[0][0]
    for i, j in product(range(n), range(m)):
        if a00 ^ A[0][j] ^ A[i][0] ^ A[i][j]:
            print(-1)
            return -1

    # Trie depth: at least L bits, widened so no Y/Z value is truncated
    # (bits above the trie depth would otherwise be ignored in comparisons).
    bits = _bit_width(A[0], (row[0] for row in A), R, C)
    root = Node()

    def add_weight(Y: int, Z: int, W: int) -> None:
        """Add W * [X ^ Y > Z] to f(X) for every X in [0, 2**bits)."""
        YZ = Y ^ Z
        cur = root
        for bit in reversed(range(bits)):
            # Follow the branch on which X ^ Y still equals Z on this bit.
            if (YZ >> bit) & 1:
                nxt = cur.right_child_or_create()
            else:
                nxt = cur.left_child_or_create()
            if not ((Z >> bit) & 1):
                # On the sibling branch X ^ Y has a 1 where Z has a 0, so
                # X ^ Y > Z there: add W to cur's subtree, cancel it on nxt.
                cur.weight += W
                nxt.weight -= W
            cur = nxt

    def add_cost(Y: int, limit: int) -> None:
        """Add cost(X ^ Y, limit) to f(X)."""
        if limit:
            add_weight(Y, 0, 1)
            add_weight(Y, limit, 1)
            add_weight(Y, 2 * floor_pow2(limit) - 1, inf)
        else:
            add_weight(Y, 0, inf)

    for i in range(n):
        add_cost(a00 ^ A[i][0], R[i])
    for j in range(m):
        add_cost(A[0][j], C[j])

    # Push accumulated weights down (BFS). A missing child means f is
    # constant on that whole subtree and equals the parent's accumulated weight.
    min_weight = inf
    queue: List[Node] = [root]
    for node in queue:  # the list grows while it is iterated
        weight = node.weight
        for child in (node.lch, node.rch):
            if child is None:
                min_weight = min(min_weight, weight)
            else:
                child.weight += weight
                queue.append(child)

    answer = -1 if min_weight >= inf else min_weight
    print(answer)
    return answer


def main() -> None:
    """Read T test cases from stdin and print one answer per case."""
    readline = sys.stdin.readline
    T = int(readline())
    for _ in range(T):
        n, m = map(int, readline().split())
        R = list(map(int, readline().split()))
        C = list(map(int, readline().split()))
        A = [list(map(int, readline().split())) for _ in range(n)]
        solve(n, m, R, C, A)


if __name__ == "__main__":
    main()