結果

問題 No.1669 パズル作成
ユーザー lam6er
提出日時 2025-04-09 21:04:45
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,900 bytes
コンパイル時間 224 ms
コンパイル使用メモリ 82,596 KB
実行使用メモリ 109,108 KB
最終ジャッジ日時 2025-04-09 21:06:22
合計ジャッジ時間 5,192 ms
ジャッジサーバーID
(参考情報)
judge2 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 13 WA * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict

class PotentialUnionFind:
    """Disjoint-set union that also tracks an XOR "potential" between members.

    Every node ``i`` carries an unknown value ``v(i)``.  ``unite(x, y, w)``
    records the constraint ``v(x) ^ v(y) == w``.  ``delta[i]`` stores
    ``v(i) ^ v(parent[i])``; right after ``find(i)`` it equals
    ``v(i) ^ v(root)`` (a root's own delta is always 0).
    """

    def __init__(self, size):
        self.parent = list(range(size))
        self.rank = [0] * size
        self.delta = [0] * size  # delta[i] is the XOR difference between i and parent[i]

    def find(self, x):
        """Return the root of ``x``, compressing the path on the way.

        Afterwards every node on the path points directly at the root, and its
        ``delta`` is the XOR difference to that root.  Implemented iteratively
        so long chains cannot hit Python's recursion limit.
        """
        path = []
        while self.parent[x] != x:
            path.append(x)
            x = self.parent[x]
        root = x
        # Process nodes from the one nearest the root outward.  Each node's
        # parent has then already been re-pointed at the root, so the parent's
        # delta is root-relative and can simply be XOR-ed in.
        for node in reversed(path):
            up = self.parent[node]
            if up != root:
                self.delta[node] ^= self.delta[up]
                self.parent[node] = root
        return root

    def unite(self, x, y, w):
        """Record the constraint ``v(x) ^ v(y) == w``.

        Returns False (and changes nothing) if the constraint contradicts the
        ones recorded so far, True otherwise.
        """
        rx = self.find(x)
        ry = self.find(y)
        if rx == ry:
            return (self.delta[x] ^ self.delta[y]) == w
        if self.rank[rx] < self.rank[ry]:
            # Union by rank: hang the shallower tree under the deeper one.
            # XOR relations are symmetric, so swapping the endpoints must NOT
            # modify w (the original code wrongly XOR-ed the deltas into it).
            rx, ry = ry, rx
            x, y = y, x
        self.parent[ry] = rx
        # v(ry) ^ v(rx) = (v(y) ^ delta[y]) ^ (v(x) ^ delta[x])
        #              = delta[y] ^ w ^ delta[x]
        self.delta[ry] = self.delta[y] ^ w ^ self.delta[x]
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1
        return True

def _min_equal_pairs(components):
    """Return the minimum number of (row, column) pairs that share a label.

    ``components`` is an iterable of ``(a, b)`` pairs.  Each component holds
    ``a`` rows and ``b`` columns that must all receive the same 0/1 label.
    Let label 0 collect ``X`` rows and ``Y`` columns, out of ``R`` rows and
    ``C`` columns in total.  The number of equal-label pairs is then
    ``X*Y + (R-X)*(C-Y)``; this function minimises it exactly.

    Method:
    - Components without columns (free rows) or without rows (free columns)
      are pooled.
    - Only mixed components go through a 0/1 knapsack.  It records, for every
      achievable row count ``x``, the smallest achievable column count.  The
      largest one follows from the complementary subset.
    - The objective is bilinear, so the optimum is reached at an extreme
      column count for some ``x``, with all free rows (and all free columns)
      placed on one side or the other.
    """
    free_rows = 0
    free_cols = 0
    shape_count = defaultdict(int)
    for a, b in components:
        if b == 0:
            free_rows += a
        elif a == 0:
            free_cols += b
        else:
            shape_count[(a, b)] += 1

    # Binary splitting: k copies of (a, b) become items of 1, 2, 4, ... copies.
    # These can still represent every multiplicity 0..k, but with only
    # O(log k) items.
    items = []
    for (a, b), k in shape_count.items():
        step = 1
        while k > 0:
            take = min(step, k)
            items.append((a * take, b * take))
            k -= take
            step *= 2

    total_a = sum(a for a, _ in items)
    total_b = sum(b for _, b in items)
    unreachable = total_b + 1  # larger than any real column count
    min_cols = [unreachable] * (total_a + 1)
    min_cols[0] = 0
    reach = 0
    for a, b in items:
        # Iterate downward so each item is used at most once (0/1 knapsack).
        for x in range(reach, -1, -1):
            if min_cols[x] != unreachable:
                cand = min_cols[x] + b
                if cand < min_cols[x + a]:
                    min_cols[x + a] = cand
        reach += a

    R = total_a + free_rows
    C = total_b + free_cols
    best = None
    for x in range(total_a + 1):
        if min_cols[x] == unreachable:
            continue
        y_lo = min_cols[x]
        # The complement of a subset with row sum total_a - x has row sum x,
        # so this is the largest column count reachable with x rows.
        y_hi = total_b - min_cols[total_a - x]
        for y in (y_lo, y_hi):
            for X in (x, x + free_rows):
                for Y in (y, y + free_cols):
                    value = X * Y + (R - X) * (C - Y)
                    if best is None or value < best:
                        best = value
    return best


def main():
    """Solve one test case read from stdin and print the answer.

    Input format: ``N M`` followed by ``M`` 1-indexed cells ``r c``.

    Rows are union-find nodes ``0..N-1`` and columns are ``N..2N-1``.  Each
    listed cell forces its row and its column to share a 0/1 label.  The
    printed value is a minimum taken over all labellings consistent with those
    constraints: the number of equal-label (row, column) pairs, minus ``M``.
    """
    data = sys.stdin.read().split()
    N = int(data[0])
    M = int(data[1])
    blacks = [
        (int(data[2 + 2 * i]) - 1, int(data[3 + 2 * i]) - 1) for i in range(M)
    ]

    uf = PotentialUnionFind(2 * N)
    for r, c in blacks:
        if not uf.unite(r, N + c, 0):
            # Cannot trigger while every constraint has w == 0 (all potentials
            # stay 0); kept as the original fallback output.
            print(N * N - M)
            return

    # Number of rows / columns attached to each union-find root.
    rows_in = defaultdict(int)
    cols_in = defaultdict(int)
    for i in range(N):
        rows_in[uf.find(i)] += 1
        cols_in[uf.find(N + i)] += 1
    roots = set(rows_in) | set(cols_in)
    components = [(rows_in.get(root, 0), cols_in.get(root, 0)) for root in roots]

    # Every black cell links a row and a column with the same label, so it is
    # counted among the equal-label pairs; subtract those M cells.
    print(_min_equal_pairs(components) - M)

# Script entry point: solve the input only when run directly, not on import.
if __name__ == '__main__':
    main()
0