結果

問題 No.2236 Lights Out On Simple Graph
ユーザー gew1fw
提出日時 2025-06-12 14:03:55
言語 PyPy3
(7.3.15)
結果
RE  
実行時間 -
コード長 5,710 bytes
コンパイル時間 384 ms
コンパイル使用メモリ 82,320 KB
実行使用メモリ 84,272 KB
最終ジャッジ日時 2025-06-12 14:04:39
合計ジャッジ時間 6,644 ms
ジャッジサーバーID
(参考情報)
judge5 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 5 WA * 2 RE * 1 TLE * 1 -- * 48
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque

def main():
    sys.setrecursionlimit(1 << 25)
    N, M = map(int, sys.stdin.readline().split())
    edges = []
    for _ in range(M):
        a, b = map(int, sys.stdin.readline().split())
        edges.append((a-1, b-1))
    c = list(map(int, sys.stdin.readline().split()))
    
    # Find connected components
    adj = [[] for _ in range(N)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    visited = [False] * N
    components = []
    for i in range(N):
        if not visited[i]:
            q = deque()
            q.append(i)
            visited[i] = True
            comp = []
            while q:
                u = q.popleft()
                comp.append(u)
                for v in adj[u]:
                    if not visited[v]:
                        visited[v] = True
                        q.append(v)
            components.append(comp)
    
    total = 0
    for comp in components:
        n_sub = len(comp)
        comp_vertices = comp
        comp_map = {v: idx for idx, v in enumerate(comp_vertices)}
        comp_edges = []
        for a, b in edges:
            if a in comp_map and b in comp_map:
                comp_edges.append((a, b))
        m_sub = len(comp_edges)
        
        sum_c = sum(c[v] for v in comp_vertices)
        if sum_c % 2 != 0:
            print(-1)
            return
        
        # Build the system A x = c in GF(2)
        # A is a n_sub x m_sub matrix
        # Each row corresponds to a vertex in comp_vertices
        # Each column corresponds to an edge in comp_edges
        system = []
        for v in comp_vertices:
            row = 0
            for j, (a, b) in enumerate(comp_edges):
                if a == v or b == v:
                    row ^= (1 << (m_sub - 1 - j))
            c_val = c[v]
            system.append( (row, c_val) )
        
        # Perform Gaussian elimination
        matrix = []
        for row, c_val in system:
            matrix_row = (row << 1) | c_val
            matrix.append(matrix_row)
        
        rank = 0
        n_rows = len(matrix)
        for col_idx in range(m_sub):
            pivot_row = None
            for r in range(rank, n_rows):
                if (matrix[r] >> (m_sub - col_idx)) & 1:
                    pivot_row = r
                    break
            if pivot_row is None:
                continue
            # Swap with the current rank row
            matrix[rank], matrix[pivot_row] = matrix[pivot_row], matrix[rank]
            # Eliminate this column in all other rows
            for r in range(n_rows):
                if r != rank and (matrix[r] >> (m_sub - col_idx)) & 1:
                    matrix[r] ^= matrix[rank]
            rank += 1
        
        # Check for inconsistency
        for row in matrix:
            a_part = row >> 1
            c_bit = row & 1
            if bin(a_part).count('1') == 0 and c_bit == 1:
                print(-1)
                return
        
        # Find the pivot columns
        pivot_cols = []
        for row in matrix:
            a_part = row >> 1
            if a_part == 0:
                continue
            # Find the leading column
            leading_col = -1
            for j in range(m_sub):
                if (a_part >> (m_sub - j - 1)) & 1:
                    leading_col = j
                    break
            if leading_col != -1:
                pivot_cols.append(leading_col)
        
        # Find particular solution x0
        x0 = 0
        pivot_cols_sorted = sorted(pivot_cols)
        for i in reversed(range(len(pivot_cols_sorted))):
            col = pivot_cols_sorted[i]
            found = False
            for r in range(n_rows):
                a_part = matrix[r] >> 1
                if (a_part >> (m_sub - col - 1)) & 1:
                    row = matrix[r]
                    found = True
                    break
            if not found:
                continue
            sum_val = 0
            c_bit = (row & 1)
            for j in range(col + 1, m_sub):
                if (row >> (m_sub - j)) & 1:
                    sum_val ^= (x0 >> j) & 1
            x0 ^= ( (c_bit ^ sum_val) << col )
        
        # Find basis for null space
        basis = []
        for j in range(m_sub):
            if j not in pivot_cols:
                vec = 0
                vec |= (1 << j)
                for r in range(n_rows):
                    a_part = matrix[r] >> 1
                    if (a_part >> (m_sub - j - 1)) & 1:
                        leading_col = -1
                        for c in range(m_sub):
                            if (a_part >> (m_sub - c - 1)) & 1:
                                leading_col = c
                                break
                        if leading_col != -1:
                            vec ^= (1 << leading_col)
                # Check if vec is non-zero and add to basis
                if vec != 0:
                    # Check linear independence
                    tmp = vec
                    for b in basis:
                        tmp ^= (tmp & b) & b
                    if tmp != 0:
                        basis.append(tmp)
        
        # Generate all possible combinations of basis vectors
        min_w = bin(x0).count('1')
        k = len(basis)
        for mask in range(0, 1 << k):
            comb = 0
            for i in range(k):
                if (mask >> i) & 1:
                    comb ^= basis[i]
            x_candidate = x0 ^ comb
            cnt = bin(x_candidate).count('1')
            if cnt < min_w:
                min_w = cnt
        total += min_w
    
    print(total)

if __name__ == "__main__":
    main()
0