結果

問題 No.2236 Lights Out On Simple Graph
ユーザー gew1fw
提出日時 2025-06-12 19:03:20
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,084 bytes
コンパイル時間 236 ms
コンパイル使用メモリ 82,604 KB
実行使用メモリ 84,548 KB
最終ジャッジ日時 2025-06-12 19:03:27
合計ジャッジ時間 6,558 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 8 TLE * 1 -- * 48
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from itertools import product

def main():
    N, M = map(int, sys.stdin.readline().split())
    edges = []
    for _ in range(M):
        a, b = map(int, sys.stdin.readline().split())
        edges.append((a, b))
    c = list(map(int, sys.stdin.readline().split()))
    
    sum_c = sum(c)
    if sum_c % 2 != 0:
        print(-1)
        return
    
    # Initialize matrix A and augmented matrix
    A = [[0] * M for _ in range(N)]
    for j in range(M):
        a, b = edges[j]
        a -= 1
        b -= 1
        A[a][j] = 1
        A[b][j] = 1
    
    # Construct augmented matrix
    aug = [row.copy() + [c[i]] for i, row in enumerate(A)]
    n_rows = N
    n_cols = M
    rank = 0
    pivot_cols = set()
    
    for col in range(n_cols):
        # Find the pivot row
        row = None
        for r in range(rank, n_rows):
            if aug[r][col] == 1:
                row = r
                break
        if row is None:
            continue
        
        # Swap rows
        aug[rank], aug[row] = aug[row], aug[rank]
        
        # Eliminate this column in other rows
        for r in range(n_rows):
            if r != rank and aug[r][col] == 1:
                for c_col in range(col, n_cols + 1):
                    aug[r][c_col] ^= aug[rank][c_col]
        
        pivot_cols.add(col)
        rank += 1
    
    # Check for inconsistency
    for r in range(rank, n_rows):
        if aug[r][n_cols] == 1:
            print(-1)
            return
    
    # Find particular solution x0
    x0 = [0] * M
    for r in range(rank):
        # Find pivot column
        pivot_col = None
        for j in range(n_cols):
            if aug[r][j] == 1:
                pivot_col = j
                break
        if pivot_col is None:
            continue
        
        sum_eq = aug[r][n_cols]
        for j in range(n_cols):
            if j == pivot_col:
                continue
            sum_eq ^= (x0[j] & aug[r][j])
        x0[pivot_col] = sum_eq % 2
    
    # Compute basis for null space
    null_basis = []
    free_vars = [j for j in range(M) if j not in pivot_cols]
    for j in free_vars:
        vec = [0] * M
        vec[j] = 1
        for r in range(rank):
            pivot_col = None
            for k in range(M):
                if aug[r][k] == 1:
                    pivot_col = k
                    break
            sum_eq = 0
            for k in range(M):
                if k == pivot_col:
                    continue
                sum_eq ^= (vec[k] & aug[r][k])
            vec[pivot_col] = sum_eq % 2
        null_basis.append(vec)
    
    min_weight = sum(x0)
    k = len(null_basis)
    if k == 0:
        print(min_weight)
        return
    
    for bits in product([0, 1], repeat=k):
        current = x0.copy()
        for i in range(k):
            if bits[i]:
                for j in range(M):
                    current[j] ^= null_basis[i][j]
        weight = sum(current)
        if weight < min_weight:
            min_weight = weight
    
    print(min_weight)

if __name__ == "__main__":
    main()
0