結果

問題 No.2236 Lights Out On Simple Graph
ユーザー gew1fw
提出日時 2025-06-12 19:01:20
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,259 bytes
コンパイル時間 207 ms
コンパイル使用メモリ 82,292 KB
実行使用メモリ 76,284 KB
最終ジャッジ日時 2025-06-12 19:01:45
合計ジャッジ時間 6,367 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 8 TLE * 1 -- * 48
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def main():
    N, M = map(int, sys.stdin.readline().split())
    edges = []
    for _ in range(M):
        a, b = map(int, sys.stdin.readline().split())
        a -= 1
        b -= 1
        edges.append((a, b))
    c = list(map(int, sys.stdin.readline().split()))

    total = sum(c)
    if total % 2 != 0:
        print(-1)
        return

    mat = []
    for i in range(N):
        row = 0
        for j in range(M):
            a, b = edges[j]
            if i == a or i == b:
                row ^= (1 << j)
        row |= (c[i] << M)
        mat.append(row)

    rank = 0
    for col in range(M):
        pivot = -1
        for r in range(rank, N):
            if (mat[r] >> col) & 1:
                pivot = r
                break
        if pivot == -1:
            continue
        mat[rank], mat[pivot] = mat[pivot], mat[rank]
        for r in range(N):
            if r != rank and (mat[r] >> col) & 1:
                mat[r] ^= mat[rank]
        rank += 1

    consistent = True
    for r in range(rank, N):
        if (mat[r] >> M) & 1:
            consistent = False
            break
    if not consistent:
        print(-1)
        return

    pivot_cols = []
    for r in range(rank):
        for col in range(M):
            if (mat[r] >> col) & 1:
                pivot_cols.append(col)
                break

    free_vars = sorted(set(range(M)) - set(pivot_cols))
    k = len(free_vars)
    min_weight = float('inf')

    for mask in range(0, 1 << k):
        x = [0] * M
        for i in range(k):
            var = free_vars[i]
            x[var] = (mask >> i) & 1

        for r in range(rank):
            row = mat[r]
            rhs = (row >> M) & 1
            pivot_col = -1
            for col in range(M):
                if (row >> col) & 1:
                    pivot_col = col
                    break
            sum_val = 0
            for col in range(pivot_col + 1, M):
                if (row >> col) & 1:
                    sum_val ^= x[col]
            x_pivot = (rhs ^ sum_val)
            x[pivot_col] = x_pivot

        weight = sum(x)
        if weight < min_weight:
            min_weight = weight

    print(min_weight if min_weight != float('inf') else -1)

if __name__ == "__main__":
    main()
0