結果
問題 | No.2236 Lights Out On Simple Graph |
ユーザー | (空の画像リンクのみで名前は表示されていない) |
提出日時 | 2025-06-12 18:23:19 |
言語 | PyPy3 (7.3.15) |
結果 | TLE |
実行時間 | - |
コード長 | 5,682 bytes |
コンパイル時間 | 304 ms |
コンパイル使用メモリ | 82,128 KB |
実行使用メモリ | 68,648 KB |
最終ジャッジ日時 | 2025-06-12 18:23:27 |
合計ジャッジ時間 | 7,079 ms |
ジャッジサーバーID (参考情報) | judge5 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 8 TLE * 1 -- * 48 |
ソースコード
import sys
from collections import deque


def main():
    """Read one test case from stdin and print its answer.

    Input (as parsed here): ``N M``, then ``M`` pairs ``a b`` of 1-based vertex
    indices, then ``N`` light states ``c_1 .. c_N``.  Prints the minimum number
    of edge toggles that clears every light, or ``-1`` if that is impossible.
    """
    data = sys.stdin.read().split()  # not ``input``: avoid shadowing the builtin
    n, m = int(data[0]), int(data[1])
    edges = [(int(data[2 + 2 * i]), int(data[3 + 2 * i])) for i in range(m)]
    offset = 2 + 2 * m
    c = [int(tok) for tok in data[offset:offset + n]]
    print(solve(n, edges, c))


def solve(n, edges, c):
    """Return the fewest edges whose toggling clears every lit vertex, or -1.

    Toggling edge (a, b) flips the parity of a and b, so a valid choice is an
    edge set J in which exactly the lit vertices T = {v : c_v odd} have odd
    degree -- a T-join.  One exists iff every connected component contains an
    even number of lit vertices.  With unit edge costs a minimum T-join has the
    same size as a minimum-weight perfect matching of T under BFS distance, so
    each component is solved with a polynomial matching.  The earlier approach
    enumerated 2^k null-space vectors (k = M_c - N_c + 1 per component), which
    is exponential on dense graphs.

    Args:
        n: number of vertices, labelled 1..n.
        edges: list of (a, b) pairs with 1-based endpoints.
        c: light states; ``c[v - 1]`` belongs to vertex v.

    Returns:
        The minimum number of toggles, or -1 when no solution exists.
    """
    adj = [[] for _ in range(n + 1)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    seen = [False] * (n + 1)
    total = 0
    for root in range(1, n + 1):
        if seen[root]:
            continue
        reach = bfs_distances(adj, root)
        component = [v for v in range(1, n + 1) if reach[v] >= 0]
        for v in component:
            seen[v] = True
        lit = [v for v in component if c[v - 1] % 2 == 1]
        if len(lit) % 2 == 1:
            # A toggle changes a component's lit count by -2, 0 or +2,
            # so an odd count can never reach zero.
            return -1
        if lit:
            rows = [bfs_distances(adj, v) for v in lit]
            dist = [[row[u] for u in lit] for row in rows]
            total += min_perfect_matching_cost(dist)
    return total


def bfs_distances(adj, source):
    """Return unweighted shortest-path lengths from ``source`` (-1 = unreachable)."""
    dist = [-1] * len(adj)
    dist[source] = 0
    frontier = deque([source])
    while frontier:
        u = frontier.popleft()
        for w in adj[u]:
            if dist[w] < 0:
                dist[w] = dist[u] + 1
                frontier.append(w)
    return dist


def min_perfect_matching_cost(dist):
    """Return the minimum total weight of a perfect matching of a complete graph.

    ``dist`` is a symmetric ``k x k`` matrix of non-negative integers with
    ``k`` even.  Weights are turned into ``big - dist`` and a maximum-weight
    matching is taken.  Because ``big > (k / 2) * max(dist)``, every perfect
    matching outweighs every smaller matching, so the optimum is perfect and
    minimises the original weight.

    Raises:
        ValueError: if ``k`` is odd.
    """
    size = len(dist)
    if size == 0:
        return 0
    if size % 2:
        raise ValueError("perfect matching needs an even number of vertices")
    max_d = max(max(row) for row in dist)
    big = size * max_d + 1
    weighted = [(i, j, big - dist[i][j]) for i in range(size) for j in range(i + 1, size)]
    mate = max_weight_matching(weighted)
    if -1 in mate:  # excluded by the choice of ``big``; keep the invariant explicit
        raise RuntimeError("matching is not perfect")
    return sum(dist[i][mate[i]] for i in range(size) if i < mate[i])


def max_weight_matching(edges):
    """Return a maximum-weight matching of a general graph as a ``mate`` list.

    ``edges`` is a list of ``(i, j, w)`` with ``i != j``, vertices ``0 .. V-1``
    and integer weights.  ``mate[v]`` is the partner of ``v``, or ``-1``.

    This is Edmonds' primal-dual blossom algorithm, O(V^3), transcribed from
    Joris van Rantwijk's public-domain ``mwmatching.py`` without its
    max-cardinality mode.  An edge end is encoded as ``p`` with vertex
    ``edges[p // 2][p % 2]``; ``p ^ 1`` is the other end of the same edge.
    Dual variables are kept doubled (slack = u_i + u_j - 2w), so integer
    weights give integer arithmetic throughout.
    """
    if not edges:
        return []
    nedge = len(edges)
    nvertex = 1 + max(max(i, j) for i, j, _ in edges)
    maxweight = max(0, max(w for _, _, w in edges))
    endpoint = [edges[p // 2][p % 2] for p in range(2 * nedge)]
    # neighbend[v]: edge ends p whose vertex endpoint[p] is a neighbour of v.
    neighbend = [[] for _ in range(nvertex)]
    for k, (i, j, _) in enumerate(edges):
        neighbend[i].append(2 * k + 1)
        neighbend[j].append(2 * k)
    # Indices >= nvertex denote non-trivial blossoms.
    mate = [-1] * nvertex               # remote edge end of the matched edge
    label = [0] * (2 * nvertex)         # 0 free, 1 S, 2 T; bit 4 marks a scan
    labelend = [-1] * (2 * nvertex)     # edge end through which the label came
    inblossom = list(range(nvertex))    # top-level blossom containing a vertex
    blossomparent = [-1] * (2 * nvertex)
    blossomchilds = [None] * (2 * nvertex)
    blossombase = list(range(nvertex)) + [-1] * nvertex
    blossomendps = [None] * (2 * nvertex)
    bestedge = [-1] * (2 * nvertex)     # least-slack edge towards an S-blossom
    blossombestedges = [None] * (2 * nvertex)
    unusedblossoms = list(range(nvertex, 2 * nvertex))
    dualvar = [maxweight] * nvertex + [0] * nvertex
    allowedge = [False] * nedge         # edge known to have zero slack
    queue = []                          # S-vertices waiting to be scanned

    def slack(k):
        i, j, wt = edges[k]
        return dualvar[i] + dualvar[j] - 2 * wt

    def blossom_leaves(b):
        # Yield every original vertex contained in (possibly nested) blossom b.
        if b < nvertex:
            yield b
        else:
            for t in blossomchilds[b]:
                if t < nvertex:
                    yield t
                else:
                    yield from blossom_leaves(t)

    def assign_label(w, t, p):
        # Label w's top-level blossom with t (1 = S, 2 = T), reached via end p.
        b = inblossom[w]
        label[w] = label[b] = t
        labelend[w] = labelend[b] = p
        bestedge[w] = bestedge[b] = -1
        if t == 1:
            queue.extend(blossom_leaves(b))
        elif t == 2:
            base = blossombase[b]
            assign_label(endpoint[mate[base]], 1, mate[base] ^ 1)

    def scan_blossom(v, w):
        # Trace back from v and w; return the base of a new blossom, or -1
        # if the two paths reach different roots (an augmenting path).
        path = []
        base = -1
        while v != -1 or w != -1:
            b = inblossom[v]
            if label[b] & 4:
                base = blossombase[b]
                break
            path.append(b)
            label[b] = 5
            if labelend[b] == -1:
                v = -1
            else:
                v = endpoint[labelend[b]]
                b = inblossom[v]
                v = endpoint[labelend[b]]
            if w != -1:
                v, w = w, v
        for b in path:
            label[b] = 1
        return base

    def add_blossom(base, k):
        # Contract the odd cycle closed by edge k into a new S-blossom.
        v, w, _ = edges[k]
        bb = inblossom[base]
        bv = inblossom[v]
        bw = inblossom[w]
        b = unusedblossoms.pop()
        blossombase[b] = base
        blossomparent[b] = -1
        blossomparent[bb] = b
        blossomchilds[b] = path = []
        blossomendps[b] = endps = []
        while bv != bb:
            blossomparent[bv] = b
            path.append(bv)
            endps.append(labelend[bv])
            v = endpoint[labelend[bv]]
            bv = inblossom[v]
        path.append(bb)
        path.reverse()
        endps.reverse()
        endps.append(2 * k)
        while bw != bb:
            blossomparent[bw] = b
            path.append(bw)
            endps.append(labelend[bw] ^ 1)
            w = endpoint[labelend[bw]]
            bw = inblossom[w]
        label[b] = 1
        labelend[b] = labelend[bb]
        dualvar[b] = 0
        for leaf in blossom_leaves(b):
            if label[inblossom[leaf]] == 2:
                # A former T-vertex becomes part of an S-blossom: scan it.
                queue.append(leaf)
            inblossom[leaf] = b
        bestedgeto = [-1] * (2 * nvertex)
        for sub in path:
            if blossombestedges[sub] is None:
                nblists = [[p // 2 for p in neighbend[leaf]] for leaf in blossom_leaves(sub)]
            else:
                nblists = [blossombestedges[sub]]
            for nblist in nblists:
                for e in nblist:
                    i, j, _ = edges[e]
                    if inblossom[j] == b:
                        i, j = j, i
                    bj = inblossom[j]
                    if (bj != b and label[bj] == 1
                            and (bestedgeto[bj] == -1 or slack(e) < slack(bestedgeto[bj]))):
                        bestedgeto[bj] = e
            blossombestedges[sub] = None
            bestedge[sub] = -1
        blossombestedges[b] = [e for e in bestedgeto if e != -1]
        bestedge[b] = -1
        for e in blossombestedges[b]:
            if bestedge[b] == -1 or slack(e) < slack(bestedge[b]):
                bestedge[b] = e

    def expand_blossom(b, endstage):
        # Dissolve blossom b into its sub-blossoms, relabelling a T-blossom
        # mid-stage so the alternating tree stays consistent.
        for s in blossomchilds[b]:
            blossomparent[s] = -1
            if s < nvertex:
                inblossom[s] = s
            elif endstage and dualvar[s] == 0:
                expand_blossom(s, endstage)
            else:
                for leaf in blossom_leaves(s):
                    inblossom[leaf] = s
        if not endstage and label[b] == 2:
            entrychild = inblossom[endpoint[labelend[b] ^ 1]]
            j = blossomchilds[b].index(entrychild)
            if j & 1:
                j -= len(blossomchilds[b])  # odd start: go forward and wrap
                jstep = 1
                endptrick = 0
            else:
                jstep = -1                  # even start: go backward
                endptrick = 1
            p = labelend[b]
            while j != 0:
                label[endpoint[p ^ 1]] = 0
                label[endpoint[blossomendps[b][j - endptrick] ^ endptrick ^ 1]] = 0
                assign_label(endpoint[p ^ 1], 2, p)
                allowedge[blossomendps[b][j - endptrick] // 2] = True
                j += jstep
                p = blossomendps[b][j - endptrick] ^ endptrick
                allowedge[p // 2] = True
                j += jstep
            bv = blossomchilds[b][j]
            label[endpoint[p ^ 1]] = label[bv] = 2
            labelend[endpoint[p ^ 1]] = labelend[bv] = p
            bestedge[bv] = -1
            j += jstep
            while blossomchilds[b][j] != entrychild:
                bv = blossomchilds[b][j]
                if label[bv] == 1:
                    j += jstep
                    continue
                for leaf in blossom_leaves(bv):
                    if label[leaf] != 0:
                        break
                if label[leaf] != 0:
                    label[leaf] = 0
                    label[endpoint[mate[blossombase[bv]]]] = 0
                    assign_label(leaf, 2, labelend[leaf])
                j += jstep
        label[b] = labelend[b] = -1
        blossomchilds[b] = blossomendps[b] = None
        blossombase[b] = -1
        blossombestedges[b] = None
        bestedge[b] = -1
        unusedblossoms.append(b)

    def augment_blossom(b, v):
        # Swap matched/unmatched edges inside blossom b so that v becomes its base.
        t = v
        while blossomparent[t] != b:
            t = blossomparent[t]
        if t >= nvertex:
            augment_blossom(t, v)
        i = j = blossomchilds[b].index(t)
        if i & 1:
            j -= len(blossomchilds[b])
            jstep = 1
            endptrick = 0
        else:
            jstep = -1
            endptrick = 1
        while j != 0:
            j += jstep
            t = blossomchilds[b][j]
            p = blossomendps[b][j - endptrick] ^ endptrick
            if t >= nvertex:
                augment_blossom(t, endpoint[p])
            j += jstep
            t = blossomchilds[b][j]
            if t >= nvertex:
                augment_blossom(t, endpoint[p ^ 1])
            mate[endpoint[p]] = p ^ 1
            mate[endpoint[p ^ 1]] = p
        blossomchilds[b] = blossomchilds[b][i:] + blossomchilds[b][:i]
        blossomendps[b] = blossomendps[b][i:] + blossomendps[b][:i]
        blossombase[b] = blossombase[blossomchilds[b][0]]

    def augment_matching(k):
        # Flip the augmenting path through edge k back to both tree roots.
        v, w, _ = edges[k]
        for s, p in ((v, 2 * k + 1), (w, 2 * k)):
            while True:
                bs = inblossom[s]
                if bs >= nvertex:
                    augment_blossom(bs, s)
                mate[s] = p
                if labelend[bs] == -1:
                    break  # reached an unmatched root
                t = endpoint[labelend[bs]]
                bt = inblossom[t]
                s = endpoint[labelend[bt]]
                j = endpoint[labelend[bt] ^ 1]
                if bt >= nvertex:
                    augment_blossom(bt, j)
                mate[j] = labelend[bt]
                p = labelend[bt] ^ 1

    for _stage in range(nvertex):
        label[:] = [0] * (2 * nvertex)
        bestedge[:] = [-1] * (2 * nvertex)
        blossombestedges[nvertex:] = [None] * nvertex
        allowedge[:] = [False] * nedge
        queue[:] = []
        for v in range(nvertex):
            if mate[v] == -1 and label[inblossom[v]] == 0:
                assign_label(v, 1, -1)
        augmented = False
        while True:
            while queue and not augmented:
                v = queue.pop()
                for p in neighbend[v]:
                    k = p // 2
                    w = endpoint[p]
                    if inblossom[v] == inblossom[w]:
                        continue  # edge internal to a blossom
                    if not allowedge[k]:
                        kslack = slack(k)
                        if kslack <= 0:
                            allowedge[k] = True
                    if allowedge[k]:
                        if label[inblossom[w]] == 0:
                            assign_label(w, 2, p ^ 1)
                        elif label[inblossom[w]] == 1:
                            base = scan_blossom(v, w)
                            if base >= 0:
                                add_blossom(base, k)
                            else:
                                augment_matching(k)
                                augmented = True
                                break
                        elif label[w] == 0:
                            # w sits inside a T-blossom; remember how it was reached.
                            label[w] = 2
                            labelend[w] = p ^ 1
                    elif label[inblossom[w]] == 1:
                        b = inblossom[v]
                        if bestedge[b] == -1 or kslack < slack(bestedge[b]):
                            bestedge[b] = k
                    elif label[w] == 0:
                        if bestedge[w] == -1 or kslack < slack(bestedge[w]):
                            bestedge[w] = k
            if augmented:
                break
            # Dual step.  delta1: smallest vertex dual (no further gain possible).
            deltatype = 1
            delta = min(dualvar[:nvertex])
            deltaedge = deltablossom = -1
            # delta2: least slack on an edge from an S-vertex to a free vertex.
            for v in range(nvertex):
                if label[inblossom[v]] == 0 and bestedge[v] != -1:
                    d = slack(bestedge[v])
                    if d < delta:
                        delta, deltatype, deltaedge = d, 2, bestedge[v]
            # delta3: half the least slack on an edge between two S-blossoms.
            for b in range(2 * nvertex):
                if blossomparent[b] == -1 and label[b] == 1 and bestedge[b] != -1:
                    d = slack(bestedge[b]) // 2
                    if d < delta:
                        delta, deltatype, deltaedge = d, 3, bestedge[b]
            # delta4: smallest dual of a top-level T-blossom (it must be expanded).
            for b in range(nvertex, 2 * nvertex):
                if (blossombase[b] >= 0 and blossomparent[b] == -1
                        and label[b] == 2 and dualvar[b] < delta):
                    delta, deltatype, deltablossom = dualvar[b], 4, b
            for v in range(nvertex):
                if label[inblossom[v]] == 1:
                    dualvar[v] -= delta
                elif label[inblossom[v]] == 2:
                    dualvar[v] += delta
            for b in range(nvertex, 2 * nvertex):
                if blossombase[b] >= 0 and blossomparent[b] == -1:
                    if label[b] == 1:
                        dualvar[b] += delta
                    elif label[b] == 2:
                        dualvar[b] -= delta
            if deltatype == 1:
                break
            if deltatype == 2:
                allowedge[deltaedge] = True
                i, j, _ = edges[deltaedge]
                if label[inblossom[i]] == 0:
                    i = j
                queue.append(i)
            elif deltatype == 3:
                allowedge[deltaedge] = True
                i, _, _ = edges[deltaedge]
                queue.append(i)
            else:
                expand_blossom(deltablossom, False)
        if not augmented:
            break
        # End of stage: dissolve S-blossoms whose dual dropped to zero.
        for b in range(nvertex, 2 * nvertex):
            if (blossomparent[b] == -1 and blossombase[b] >= 0
                    and label[b] == 1 and dualvar[b] == 0):
                expand_blossom(b, True)
    return [endpoint[p] if p >= 0 else -1 for p in mate]


# The GF(2) helpers below come from the earlier linear-algebra solution.
# main() no longer calls them; they are kept (behaviour unchanged) so that any
# existing callers keep working.


def gaussian_elimination(matrix, m_vars):
    """Reduce a GF(2) system to reduced row-echelon form in place.

    ``matrix`` is a list of ``(row_mask, rhs)`` tuples; bit ``j`` of
    ``row_mask`` is the coefficient of variable ``j``.  Returns the reduced
    list, or ``None`` if the system is inconsistent.
    """
    n_rows = len(matrix)
    current_pivot_row = 0
    for col in range(m_vars):
        pivot_row = None
        for r in range(current_pivot_row, n_rows):
            if (matrix[r][0] >> col) & 1:
                pivot_row = r
                break
        if pivot_row is None:
            continue
        matrix[current_pivot_row], matrix[pivot_row] = matrix[pivot_row], matrix[current_pivot_row]
        pivot_mask, pivot_rhs = matrix[current_pivot_row]
        for r in range(n_rows):
            if r != current_pivot_row and (matrix[r][0] >> col) & 1:
                matrix[r] = (matrix[r][0] ^ pivot_mask, matrix[r][1] ^ pivot_rhs)
        current_pivot_row += 1
    for r in range(current_pivot_row, n_rows):
        if matrix[r][1] != 0:
            return None
    return matrix


def find_x0(matrix, m_vars):
    """Return one solution of a reduced GF(2) system, leaving free variables 0."""
    x0 = [0] * m_vars
    for mask, rhs in matrix:
        pivot_col = next((col for col in range(m_vars) if (mask >> col) & 1), None)
        if pivot_col is None:
            continue
        sum_val = 0
        rest = mask & ~(1 << pivot_col)
        var = 0
        while rest:
            if rest & 1:
                sum_val ^= x0[var]
            rest >>= 1
            var += 1
        x0[pivot_col] = rhs ^ sum_val
    return x0


def find_null_space_basis(matrix, m_vars):
    """Return one null-space basis vector per free column of a reduced GF(2) system."""
    def leading(mask):
        return next((col for col in range(m_vars) if (mask >> col) & 1), None)

    pivot_cols = {leading(mask) for mask, _ in matrix} - {None}
    basis = []
    for free in range(m_vars):
        if free in pivot_cols:
            continue
        y = [0] * m_vars
        y[free] = 1
        for mask, _ in matrix:
            pivot_col = leading(mask)
            if pivot_col is None:
                continue
            sum_val = 0
            rest = mask & ~(1 << pivot_col)
            var = 0
            while rest:
                if rest & 1:
                    sum_val ^= y[var]
                rest >>= 1
                var += 1
            y[pivot_col] = sum_val
        basis.append(y)
    return basis


def minimal_solution(x0, basis):
    """Return the minimum Hamming weight in ``x0 + span(basis)``.

    Enumerates all 2**len(basis) combinations, so it is exponential in the
    number of basis vectors.
    """
    if not basis:
        return sum(x0)
    x0_mask = sum(1 << i for i, bit in enumerate(x0) if bit)
    basis_masks = [sum(1 << i for i, bit in enumerate(vec) if bit) for vec in basis]
    best = float('inf')
    for choice in range(1 << len(basis_masks)):
        combo = 0
        for i, bm in enumerate(basis_masks):
            if (choice >> i) & 1:
                combo ^= bm
        best = min(best, bin(x0_mask ^ combo).count('1'))
    return best


if __name__ == '__main__':
    main()