結果
| 問題 | No.2505 matriX cOnstRuction |
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2023-07-27 16:46:57 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 1,591 ms / 2,500 ms |
| コード長 | 2,513 bytes |
| コンパイル時間 | 399 ms |
| コンパイル使用メモリ | 82,176 KB |
| 実行使用メモリ | 380,552 KB |
| 最終ジャッジ日時 | 2024-09-15 18:49:47 |
| 合計ジャッジ時間 | 30,280 ms |
| ジャッジサーバーID (参考情報) | judge3 / judge6 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 64 |
ソースコード
from itertools import product
import sys
from time import perf_counter
from typing import List, Tuple
def floor_pow2(n: int) -> int:
    """Return the largest power of two that does not exceed *n*.

    For any ``n`` below 1 the result is 1, because the smallest value this
    function ever produces is ``2 ** 0``.
    """
    if n < 1:
        return 1
    # The highest set bit of n is exactly the largest power of two <= n.
    return 1 << (n.bit_length() - 1)
class Node:
    """One node of a binary trie carrying an additive integer weight.

    ``lch`` / ``rch`` are the children (``None`` until created) and
    ``weight`` is an integer that starts at 0.
    """

    # Many nodes are allocated per test case; fixed slots avoid allocating a
    # per-instance __dict__ and keep the memory footprint down.
    __slots__ = ("lch", "rch", "weight")

    def __init__(self) -> None:
        self.lch: "Node | None" = None
        self.rch: "Node | None" = None
        self.weight: int = 0

    def left_child_or_create(self) -> 'Node':
        """Return the left child, creating an empty one first if absent."""
        if self.lch is None:
            self.lch = Node()
        return self.lch

    def right_child_or_create(self) -> 'Node':
        """Return the right child, creating an empty one first if absent."""
        if self.rch is None:
            self.rch = Node()
        return self.rch
# Number of bits walked per trie insertion (bit L-1 down to bit 0).
L = 30
# Sentinel weight: used for forbidden choices and as the initial minimum;
# a final minimum of at least this value is reported as -1.
inf = 2 ** 30
def solve(n: int, m: int, R: List[int], C: List[int], A: List[List[int]]):
    """Print the minimum accumulated weight over all L-bit values X, or -1.

    Steps performed:

    1. Every cell must satisfy
       ``A[0][0] ^ A[0][j] ^ A[i][0] ^ A[i][j] == 0``; otherwise ``-1`` is
       printed and the function returns.
    2. A binary trie over the bits of X accumulates
       ``f(X) = sum(W * [X ^ Y > Z])``.  Row ``i`` contributes with
       ``Y = A[0][0] ^ A[i][0]`` and limit ``R[i]``; column ``j`` contributes
       with ``Y = A[0][j]`` and limit ``C[j]``.  For ``x = X ^ Y`` and a
       non-zero limit the contribution is 0 when ``x == 0``, 1 when
       ``1 <= x <= limit``, 2 when ``limit < x < 2 * floor_pow2(limit)`` and
       ``inf`` beyond that; for a zero limit any ``x > 0`` costs ``inf``.
    3. The minimum of ``f`` is printed, or ``-1`` when it is ``>= inf``.

    The result is written to stdout; nothing is returned.
    """
    # Consistency check, visiting cells in row-major order.
    if any(
        A[0][0] ^ A[0][j] ^ A[i][0] ^ A[i][j]
        for i in range(n)
        for j in range(m)
    ):
        print(-1)
        return

    root = Node()

    def add_weight(Y: int, Z: int, W: int):
        """Apply f(X) += W * [X ^ Y > Z] to the trie."""
        path = Y ^ Z
        node = root
        for bit in range(L - 1, -1, -1):
            if (path >> bit) & 1:
                child = node.right_child_or_create()
            else:
                child = node.left_child_or_create()
            if (Z >> bit) & 1 == 0:
                # Z has a 0 here, so leaving the path at this bit makes
                # X ^ Y exceed Z: add W to the whole subtree of `node` and
                # cancel it again on the on-path child.
                node.weight += W
                child.weight -= W
            node = child

    def add_constraint(Y: int, limit: int):
        """Register the weights for one row or column with the given limit."""
        if limit:
            add_weight(Y, 0, 1)
            add_weight(Y, limit, 1)
            add_weight(Y, 2 * floor_pow2(limit) - 1, inf)
        else:
            add_weight(Y, 0, inf)

    for i in range(n):
        add_constraint(A[0][0] ^ A[i][0], R[i])
    for j in range(m):
        add_constraint(A[0][j], C[j])

    # Push the subtree deltas down to obtain absolute weights.  A missing
    # child stands for every X below that prefix, all sharing the parent's
    # accumulated weight.  The traversal order does not affect the minimum.
    best = inf
    pending: List[Node] = [root]
    while pending:
        node = pending.pop()
        total = node.weight
        for child in (node.lch, node.rch):
            if child is None:
                if total < best:
                    best = total
            else:
                child.weight += total
                pending.append(child)

    print(-1 if best >= inf else best)
# Read lines straight from stdin instead of going through the builtin input().
input = sys.stdin.readline

# Input layout: T test cases, each consisting of "n m", the n values of R,
# the m values of C, and then n rows of A.
T = int(input())
for _ in range(T):
    n, m = (int(tok) for tok in input().split())
    R = [int(tok) for tok in input().split()]
    C = [int(tok) for tok in input().split()]
    A = [[int(tok) for tok in input().split()] for _ in range(n)]
    solve(n, m, R, C, A)