結果

問題 No.1296 OR or NOR
ユーザー lam6er
提出日時 2025-04-09 21:01:14
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 4,787 bytes
コンパイル時間 235 ms
コンパイル使用メモリ 82,716 KB
実行使用メモリ 268,792 KB
最終ジャッジ日時 2025-04-09 21:02:30
合計ジャッジ時間 8,549 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other TLE * 1 -- * 32
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict

def main():
    sys.setrecursionlimit(1 << 25)
    N = int(sys.stdin.readline())
    a = list(map(int, sys.stdin.readline().split()))
    Q = int(sys.stdin.readline())
    queries = list(map(int, sys.stdin.readline().split()))
    
    ones = [[] for _ in range(60)]
    for i in range(N):
        for j in range(60):
            if (a[i] >> j) & 1:
                ones[j].append(i)
    
    for b in queries:
        mandatory = [-1] * N
        conflict = False
        equations = []
        S_list = []
        for j in range(60):
            desired = (b >> j) & 1
            ones_j = ones[j]
            if not ones_j:
                continue
            last_i = ones_j[-1]
            if desired:
                if mandatory[last_i] == 1:
                    conflict = True
                    break
                mandatory[last_i] = 0
            else:
                if mandatory[last_i] == 0:
                    conflict = True
                    break
                mandatory[last_i] = 1
            S_j = []
            for i in range(last_i + 1, N):
                if not (a[i] >> j) & 1:
                    S_j.append(i)
            S_list.append(S_j)
        if conflict:
            print(-1)
            continue
        
        no_ones_j_list = []
        sum_mandatory_op2 = sum(1 for i in range(N) if mandatory[i] == 1)
        for j in range(60):
            ones_j = ones[j]
            if ones_j:
                continue
            desired = (b >> j) & 1
            total_op2 = sum_mandatory_op2
            required_parity = (desired - total_op2) % 2
            no_ones_j_list.append(required_parity)
        
        if no_ones_j_list:
            required = no_ones_j_list[0]
            if not all(r == required for r in no_ones_j_list):
                print(-1)
                continue
        
        mandatory_op2_count = sum_mandatory_op2
        var_steps = [i for i in range(N) if mandatory[i] == -1]
        step_to_var = {i: idx for idx, i in enumerate(var_steps)}
        M = len(var_steps)
        eqs = []
        
        for S_j in S_list:
            vec = 0
            for i in S_j:
                if mandatory[i] == -1:
                    vec ^= (1 << step_to_var[i])
            if vec != 0 or 0:  
                eqs.append((vec, 0))
        
        if no_ones_j_list:
            required = no_ones_j_list[0]
            vec = 0
            for i in var_steps:
                vec ^= (1 << step_to_var[i])
            eqs.append((vec, required))
        
        rank, rows = 0, []
        for (vec, rhs) in eqs:
            row = (vec << 1) | rhs
            pivot = -1
            for i in range(M-1, -1, -1):
                if (row >> (i+1)) & 1:
                    pivot = i
                    break
            if pivot == -1:
                if row & 1:
                    rank = -1
                    break
                else:
                    continue
            rows.append(row)
            for r in rows[:-1]:
                if (r >> (pivot + 1)) & 1:
                    rows[-1] ^= r
            new_row = rows[-1]
            pivot_new = -1
            for i in range(M-1, -1, -1):
                if (new_row >> (i+1)) & 1:
                    pivot_new = i
                    break
            if pivot_new == -1:
                if new_row & 1:
                    rank = -1
                    break
                else:
                    rows.pop()
                    continue
            for i in range(len(rows) - 1):
                r = rows[i]
                if (r >> (pivot_new + 1)) & 1:
                    rows[i] ^= new_row
            rows.sort(reverse=True)
            rank = len(rows)
        
        if rank == -1:
            print(-1)
            continue
        
        for i in range(len(rows)):
            for j in range(i+1, len(rows)):
                if rows[j] > rows[i]:
                    rows[i], rows[j] = rows[j], rows[i]
        
        extra = 0
        assigned = {}
        for row in reversed(rows):
            vec = row >> 1
            rhs = row & 1
            pivot = -1
            for i in range(M):
                if (vec >> i) & 1:
                    pivot = i
                    break
            if pivot == -1:
                continue
            lhs = 0
            for i in range(pivot+1, M):
                if (vec >> i) & 1:
                    lhs ^= assigned.get(i, 0)
            val = (rhs ^ lhs) & 1
            assigned[pivot] = val
        
        res = mandatory_op2_count
        for i in var_steps:
            var_idx = step_to_var[i]
            if assigned.get(var_idx, 0):
                res += 1
        print(res if res >=0 else -1)
    return

main()
0