結果

問題 No.2236 Lights Out On Simple Graph
ユーザー gew1fw
提出日時 2025-06-12 18:23:19
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 5,682 bytes
コンパイル時間 304 ms
コンパイル使用メモリ 82,128 KB
実行使用メモリ 68,648 KB
最終ジャッジ日時 2025-06-12 18:23:27
合計ジャッジ時間 7,079 ms
ジャッジサーバーID (参考情報) judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 8 TLE * 1 -- * 48
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque

def main():
    """Read one test case from stdin and print the minimum number of toggles.

    Input format: ``N M``, then M edges ``a b`` (1-based vertices), then N
    light states ``c_1 .. c_N``. Toggling an edge flips the lights on both
    of its endpoints. For each connected component the GF(2) system
    "XOR of incident edge variables == c_v" is solved, and the minimum
    number of chosen edges is summed over components. Prints -1 when some
    component has no solution.
    """
    tokens = sys.stdin.read().split()
    n, m = int(tokens[0]), int(tokens[1])
    edges = [(int(tokens[2 + 2 * i]), int(tokens[3 + 2 * i])) for i in range(m)]
    c = list(map(int, tokens[2 + 2 * m:2 + 2 * m + n]))

    adj = [[] for _ in range(n + 1)]  # 1-based adjacency lists
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    # Label connected components with BFS; comp_id[v] is v's component index.
    comp_id = [-1] * (n + 1)
    components = []
    for start in range(1, n + 1):
        if comp_id[start] != -1:
            continue
        label = len(components)
        comp_id[start] = label
        comp = []
        queue = deque([start])
        while queue:
            u = queue.popleft()
            comp.append(u)
            for nei in adj[u]:
                if comp_id[nei] == -1:
                    comp_id[nei] = label
                    queue.append(nei)
        components.append(comp)

    # Each toggle flips two lights of the same component, so the parity of
    # the lit count per component never changes: an odd count is unsolvable.
    for comp in components:
        if sum(c[v - 1] for v in comp) % 2:
            print(-1)
            return

    # Bucket edges by component in a single pass (both endpoints share a
    # label), keeping input order. This replaces the per-component
    # `a in comp` list-membership scan over all edges.
    edges_by_comp = [[] for _ in components]
    for a, b in edges:
        edges_by_comp[comp_id[a]].append((a, b))

    total_ops = 0
    for comp, comp_edges in zip(components, edges_by_comp):
        m_vars = len(comp_edges)
        if m_vars == 0:
            # A component with no edges is a single vertex nothing can toggle.
            if any(c[v - 1] == 1 for v in comp):
                print(-1)
                return
            continue

        # One equation per vertex: bit i of the row mask is set when edge i
        # is incident to that vertex; the right-hand side is the vertex's light.
        row_of = {v: i for i, v in enumerate(comp)}
        masks = [0] * len(comp)
        for edge_idx, (a, b) in enumerate(comp_edges):
            bit = 1 << edge_idx
            masks[row_of[a]] |= bit
            masks[row_of[b]] |= bit
        matrix = [(mask, c[v - 1]) for mask, v in zip(masks, comp)]

        mat = gaussian_elimination(matrix, m_vars)
        if mat is None:
            print(-1)
            return

        x0 = find_x0(mat, m_vars)  # one particular solution
        basis = find_null_space_basis(mat, m_vars)  # all other solutions differ by span(basis)
        total_ops += minimal_solution(x0, basis)

    print(total_ops)

def gaussian_elimination(matrix, m_vars):
    """Bring an augmented GF(2) system to reduced row echelon form in place.

    Args:
        matrix: list of rows ``(mask, rhs)``; bit ``col`` of ``mask`` is the
            coefficient of variable ``col`` and ``rhs`` is the constant term.
            The list itself is modified (rows swapped and replaced).
        m_vars: number of variables (columns) to eliminate.

    Returns:
        The same list object after elimination, or None when a row with an
        all-zero mask keeps a non-zero right-hand side (inconsistent system).
    """
    height = len(matrix)
    rank = 0
    for col in range(m_vars):
        bit = 1 << col
        # First not-yet-used row that has a 1 in this column, if any.
        found = next((r for r in range(rank, height) if matrix[r][0] & bit), None)
        if found is None:
            continue
        matrix[rank], matrix[found] = matrix[found], matrix[rank]
        lead_mask, lead_rhs = matrix[rank]
        # Clear this column from every other row (above and below).
        for r, (mask, rhs) in enumerate(matrix):
            if r != rank and mask & bit:
                matrix[r] = (mask ^ lead_mask, rhs ^ lead_rhs)
        rank += 1
    # Rows past the rank are all-zero on the left; a non-zero rhs means 0 == 1.
    if any(rhs != 0 for _, rhs in matrix[rank:]):
        return None
    return matrix

def find_x0(matrix, m_vars):
    """Back-substitute one particular solution of an eliminated GF(2) system.

    Non-pivot (free) variables stay 0. For each row with a pivot (its lowest
    set bit among the first ``m_vars`` columns), the pivot variable is set to
    the row's rhs XOR the current values of the row's other variables. Rows
    are processed in list order, so earlier assignments are visible to later
    rows.

    Args:
        matrix: rows ``(mask, rhs)``, e.g. as returned by gaussian_elimination.
        m_vars: number of variables.

    Returns:
        A list of length ``m_vars`` holding the assigned values.
    """
    solution = [0] * m_vars
    for mask, rhs in matrix:
        lead = next((col for col in range(m_vars) if mask & (1 << col)), None)
        if lead is None:
            continue  # all-zero row: constrains nothing
        others = mask & ~(1 << lead)
        parity = 0
        for var in range(others.bit_length()):
            if (others >> var) & 1:
                parity ^= solution[var]
        solution[lead] = rhs ^ parity
    return solution

def find_null_space_basis(matrix, m_vars):
    """Return a basis of the null space of an eliminated GF(2) matrix.

    Each column that is not the pivot (lowest set bit) of any row is a free
    variable. For every free variable one vector ``y`` is produced: that free
    variable is 1, the other free variables are 0, and each pivot variable is
    set to the XOR of the row's other variables so that the row evaluates to 0.

    Args:
        matrix: rows ``(mask, rhs)`` as returned by gaussian_elimination; the
            rhs values are ignored.
        m_vars: number of variables.

    Returns:
        A list of basis vectors, each a list of ``m_vars`` 0/1 values (empty
        when there are no free variables).
    """
    # Pivot column and remaining bits of every non-zero row. These do not
    # depend on the free variable, so they are computed once up front instead
    # of once per free variable.
    pivot_rows = []
    for mask, _rhs in matrix:
        pivot_col = next((col for col in range(m_vars) if (mask >> col) & 1), None)
        if pivot_col is not None:
            pivot_rows.append((pivot_col, mask & ~(1 << pivot_col)))

    pivot_cols = {col for col, _ in pivot_rows}
    free_vars = [col for col in range(m_vars) if col not in pivot_cols]

    basis = []
    for free in free_vars:
        y = [0] * m_vars
        y[free] = 1
        # Rows are applied in order, so values set by earlier rows are read
        # by later ones, matching the back-substitution in find_x0.
        for pivot_col, rest in pivot_rows:
            acc = 0
            var = 0
            while rest:
                if rest & 1:
                    acc ^= y[var]
                rest >>= 1
                var += 1
            y[pivot_col] = acc
        basis.append(y)
    return basis

def minimal_solution(x0, basis):
    """Return the minimum Hamming weight over the coset ``x0 + span(basis)``.

    Every solution of the linear system equals ``x0`` XOR some combination
    of the basis vectors. This function tries all ``2**k`` combinations
    (``k = len(basis)``) and returns the smallest number of set entries.

    Combinations are visited in Gray-code order. Consecutive candidates differ
    by exactly one basis vector, so each step is a single XOR rather than
    rebuilding the combination (``O(2**k)`` instead of ``O(k * 2**k)``).
    NOTE(review): the search is still exponential in k, so large null spaces
    (e.g. dense graphs in the caller) remain slow.

    Args:
        x0: a particular solution, a list of 0/1 values.
        basis: null-space basis vectors, each a 0/1 list the same length as x0.

    Returns:
        The minimum weight; ``sum(x0)`` when ``basis`` is empty.
    """
    if not basis:
        return sum(x0)

    def to_mask(vec):
        # Distinct bit positions, so summing equals OR-ing.
        return sum(1 << i for i, bit in enumerate(vec) if bit)

    basis_masks = [to_mask(vec) for vec in basis]
    current = to_mask(x0)
    best = bin(current).count('1')
    for step in range(1, 1 << len(basis_masks)):
        # Gray code: the step-th change toggles the basis vector whose index
        # is the position of step's lowest set bit.
        current ^= basis_masks[(step & -step).bit_length() - 1]
        weight = bin(current).count('1')
        if weight < best:
            best = weight
            if best == 0:
                break  # cannot do better than zero
    return best

if __name__ == '__main__':
    # Run the solver only when executed as a script, not when imported.
    main()
0