結果
問題 |
No.1669 パズル作成
|
ユーザー |
![]() |
提出日時 | 2025-04-09 21:04:45 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 2,900 bytes |
コンパイル時間 | 224 ms |
コンパイル使用メモリ | 82,596 KB |
実行使用メモリ | 109,108 KB |
最終ジャッジ日時 | 2025-04-09 21:06:22 |
合計ジャッジ時間 | 5,192 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 13 WA * 16 |
ソースコード
import sys from collections import defaultdict class PotentialUnionFind: def __init__(self, size): self.parent = list(range(size)) self.rank = [0] * size self.delta = [0] * size # delta[i] is the XOR difference between i and parent[i] def find(self, x): if self.parent[x] == x: return x orig_parent = self.parent[x] self.parent[x] = self.find(orig_parent) self.delta[x] ^= self.delta[orig_parent] return self.parent[x] def unite(self, x, y, w): rx = self.find(x) ry = self.find(y) if rx == ry: if (self.delta[x] ^ self.delta[y]) != w: return False return True if self.rank[rx] < self.rank[ry]: rx, ry = ry, rx x, y = y, x w ^= self.delta[x] ^ self.delta[y] self.parent[ry] = rx self.delta[ry] = self.delta[y] ^ w ^ self.delta[x] if self.rank[rx] == self.rank[ry]: self.rank[rx] += 1 return True def main(): input = sys.stdin.read().split() ptr = 0 N = int(input[ptr]) ptr += 1 M = int(input[ptr]) ptr += 1 blacks = [] for _ in range(M): r = int(input[ptr]) - 1 ptr += 1 c = int(input[ptr]) - 1 ptr += 1 blacks.append((r, c)) uf = PotentialUnionFind(2 * N) conflict = False for r, c in blacks: row_node = r col_node = N + c if not uf.unite(row_node, col_node, 0): conflict = True break if conflict: print(N * N - M) return component_info = defaultdict(lambda: {'rows': 0, 'cols': 0, 'black_count': 0}) for r in range(N): root = uf.find(r) component_info[root]['rows'] += 1 for c in range(N): col_node = N + c root = uf.find(col_node) component_info[root]['cols'] += 1 for r, c in blacks: row_node = r col_node = N + c root = uf.find(row_node) component_info[root]['black_count'] += 1 constant_part = 0 components = [] for key in component_info: ci = component_info[key] a = ci['rows'] b = ci['cols'] black = ci['black_count'] constant_part += a * b - black components.append((a, b)) components.sort(key=lambda x: -(x[0] + x[1])) total0_row, total0_col = 0, 0 total1_row, total1_col = 0, 0 loss = 0 for a, b in components: loss0 = a * total0_col + b * total0_row loss1 = a * total1_col + b * total1_row if loss0 <= loss1: loss += loss0 total0_row += a total0_col += b else: loss += loss1 total1_row += a total1_col += b total = constant_part + loss print(total) if __name__ == '__main__': main()