
問題 No.2901 Logical Sum of Substring
ユーザー 👑 hahhohahho
提出日時 2024-04-25 14:32:51
言語 PyPy3
実行時間 1,269 ms / 3,000 ms
コード長 4,463 bytes
コンパイル時間 265 ms
コンパイル使用メモリ 82,148 KB
実行使用メモリ 276,720 KB
最終ジャッジ日時 2024-09-20 20:51:09
合計ジャッジ時間 31,640 ms
judge2 / judge3


diff #

from typing import *

T = TypeVar('T')
BinaryOperator = Callable[[T, T], T]
E = TypeVar('E')


class SegmentTree(Generic[E]):
    """Iterative, bottom-up segment tree over values combined by ``operator``.

    The tree is stored in a flat list of length ``2 * size`` where ``size`` is
    a power of two; leaves occupy ``data[size:]`` and node ``i`` holds
    ``operator(data[2 * i], data[2 * i + 1])``.  Operands are always combined
    in left-to-right order, so ``operator`` does not need to be commutative.
    For range results to be meaningful it should be associative with
    ``identity`` as its neutral element.

    Build instances with :meth:`all_identity` or :meth:`from_initial_data`
    rather than calling the constructor directly.
    """

    @classmethod
    def all_identity(cls, operator: BinaryOperator[E], identity: E, size: int) -> 'SegmentTree[E]':
        """Create a tree with room for at least ``size`` leaves, all set to ``identity``."""
        # 2 << (size - 1).bit_length() == 2 * (smallest power of two >= size).
        return cls(operator, identity, [identity] * (2 << (size - 1).bit_length()))

    @classmethod
    def from_initial_data(cls, operator: BinaryOperator[E], identity: E, data: MutableSequence[E]) -> 'SegmentTree[E]':
        """Create a tree whose first ``len(data)`` leaves are ``data``.

        The remaining leaves (padding up to a power of two) are ``identity``.
        The ``data`` argument itself is not modified.
        """
        size = 1 << (len(data) - 1).bit_length()
        nodes = [identity] * (2 * size)
        nodes[size:size + len(data)] = data

        # Fill internal nodes bottom-up; children always have larger indices.
        for i in reversed(range(size)):
            nodes[i] = operator(nodes[2 * i], nodes[2 * i + 1])
        return cls(operator, identity, nodes)

    # Prefer the factory methods above over calling this directly.
    def __init__(self, operator: BinaryOperator[E], identity: E, data: MutableSequence[E]):
        """Wrap an already laid-out node array ``data`` of length ``2 * size``."""
        self.op = operator
        self.id = identity
        self.data = data
        self.size = len(data) // 2

    def reduce(self, l: int, r: int) -> E:
        """Return the combination of the leaves in the half-open range ``[l, r)``.

        An empty range yields ``operator(identity, identity)``.
        """
        l += self.size
        r += self.size
        # Accumulate separately from the left and from the right so that the
        # operands keep their original order for non-commutative operators.
        vl = self.id
        vr = self.id

        while l < r:
            if l & 1:
                vl = self.op(vl, self.data[l])
                l += 1
            if r & 1:
                r -= 1
                vr = self.op(self.data[r], vr)
            l >>= 1
            r >>= 1
        return self.op(vl, vr)

    def __getitem__(self, i: Union[slice, int]) -> E:
        """Return leaf ``i``, or the reduction over ``i.start:i.stop`` for a slice.

        Missing slice bounds default to ``0`` and ``self.size``; the slice
        step is ignored.
        """
        if isinstance(i, slice):
            return self.reduce(
                0 if i.start is None else i.start,
                self.size if i.stop is None else i.stop)
        return self.data[i + self.size]

    def __setitem__(self, i: int, v: E):
        """Set leaf ``i`` to ``v`` and recompute every ancestor up to the root."""
        i += self.size
        while i:
            self.data[i] = v
            # i ^ 1 is the sibling; keep left/right operand order intact.
            v = self.op(self.data[i ^ 1], v) if i & 1 else self.op(v, self.data[i ^ 1])
            i >>= 1

    def __iter__(self) -> Iterator[E]:
        """Iterate over all leaves, including the identity padding."""
        return iter(self.data[self.size:])

def min2(x, y):
    """Return ``x`` when ``x < y`` holds, otherwise ``y`` (ties yield ``y``)."""
    if x < y:
        return x
    return y

# Large integer sentinel (2 to the 60th power).
INF = 1 << 60

def solve(k, aa, queries):
    """Answer shortest-subarray-with-full-OR queries under point updates.

    The target value is ``(1 << k) - 1``.  A segment tree stores, per segment,
    a 4-tuple ``(prefix, suffix, opt, block)``:

    * ``prefix`` / ``suffix``: lists of ``(or_value, length)`` pairs, one per
      distinct OR obtainable from a prefix / suffix of the segment, in order
      of increasing length and always starting with ``(0, 0)``.
    * ``opt``: minimum length of a subarray inside the segment whose OR is the
      target, or ``INF`` if none exists.
    * ``block``: number of elements in the segment.

    Args:
        k: Number of bits of the target mask.
        aa: Initial array values.
        queries: Iterable of ``(mode, a, b)``.  ``mode == 1`` assigns value
            ``b`` to 0-based index ``a``; any other mode asks about the
            half-open 0-based index range ``[a, b)``.

    Returns:
        A list with one entry per range query, in order: the minimum length
        found, or ``-1`` when no subarray in the range reaches the target.
    """
    MASK = (1 << k) - 1

    def leaf(value):
        # Node for a single element; prefix and suffix get separate lists.
        ors = [(0, 0), (value, 1)] if value != 0 else [(0, 0)]
        return ors, list(ors), (1 if value == MASK else INF), 1

    def op(x, y):
        # Merge a left segment x with the right segment y that follows it.
        x_prefix, x_suffix, x_opt, x_block = x
        y_prefix, y_suffix, y_opt, y_block = y
        z_opt = min2(x_opt, y_opt)

        # Best subarray crossing the boundary: walk y's prefixes from longest
        # to shortest while advancing through x's suffixes from shortest to
        # longest.  A shorter y-prefix never needs a shorter x-suffix, so the
        # pointer i only moves forward.
        i = 0
        for p, t in reversed(y_prefix):
            while i < len(x_suffix) and x_suffix[i][0] | p != MASK:
                i += 1
            if i >= len(x_suffix):
                # No suffix of x completes this prefix, so none will complete
                # any shorter one either.
                break
            z_opt = min2(z_opt, t + x_suffix[i][1])

        # Prefixes of the merged segment: all of x's, then y's prefixes
        # extended by the whole of x whenever they add new bits.
        z_prefix = x_prefix[:]
        if z_prefix[-1][0] != MASK:
            for v, t in y_prefix:
                u = v | z_prefix[-1][0]
                if u != z_prefix[-1][0]:
                    z_prefix.append((u, t + x_block))

        # Suffixes, symmetrically.
        z_suffix = y_suffix[:]
        if z_suffix[-1][0] != MASK:
            for v, l in x_suffix:
                u = v | z_suffix[-1][0]
                if u != z_suffix[-1][0]:
                    z_suffix.append((u, l + y_block))
        return z_prefix, z_suffix, z_opt, x_block + y_block

    seg = SegmentTree.from_initial_data(op,
                                        identity=([(0, 0)], [(0, 0)], INF, 0),
                                        data=[leaf(s) for s in aa])

    res = []
    for query in queries:
        mode, *tokens = query
        if mode == 1:
            i, v = tokens
            seg[i] = leaf(v)
        else:
            l, r = tokens
            p = seg[l:r][2]
            res.append(p if p < INF else -1)
    return res

# Read "n k", the n array values, the query count and the queries from
# standard input, then print one answer per line.
n, k = map(int, input().split())
aa = list(map(int, input().split()))
q = int(input())
queries = []
for _ in range(q):
    a, b, c = map(int, input().split())
    # solve() indexes from 0 and treats ranges as half-open, so the first
    # index is shifted down by one: "1 i v" -> (1, i - 1, v) and
    # "2 l r" -> (2, l - 1, r).
    # NOTE(review): assumes the input uses 1-based indices with an inclusive
    # right end; confirm against the problem statement.
    queries.append((a, b - 1, c))
print(*solve(k, aa, queries), sep='\n')