結果

問題 No.3423 Minimum Xor Query
コンテスト
ユーザー kidodesu
提出日時 2026-01-10 13:30:52
言語 PyPy3
(7.3.17)
結果
AC  
実行時間 1,008 ms / 5,000 ms
コード長 8,613 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 264 ms
コンパイル使用メモリ 82,824 KB
実行使用メモリ 106,224 KB
最終ジャッジ日時 2026-01-11 13:15:22
合計ジャッジ時間 12,077 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 18
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

# https://github.com/tatyam-prime/SortedSet/blob/main/SortedMultiset.py
import math
from bisect import bisect_left, bisect_right
class SortedMultiset:
    """Sorted multiset stored as a list of sorted, non-empty buckets.

    ``self.a`` is the list of buckets (each a sorted list) and ``self.size``
    is the total number of stored elements.  Lookups scan the buckets
    linearly and then bisect inside the chosen bucket.
    """

    BUCKET_RATIO = 16
    SPLIT_RATIO = 24

    # Instances define __eq__ and are mutable, so they are not hashable.
    __hash__ = None

    def __init__(self, a=()):
        """Make a new SortedMultiset from iterable. / O(N) if sorted / O(N log N)"""
        # An immutable default avoids sharing one list object between calls.
        a = list(a)
        n = self.size = len(a)
        if any(a[i] > a[i + 1] for i in range(n - 1)):
            a.sort()
        # For n == 0 this yields zero buckets, i.e. self.a == [].
        num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO)))
        self.a = [a[n * i // num_bucket : n * (i + 1) // num_bucket] for i in range(num_bucket)]

    def __iter__(self):
        """Yield the elements in ascending order."""
        for bucket in self.a:
            for item in bucket:
                yield item

    def __reversed__(self):
        """Yield the elements in descending order."""
        for bucket in reversed(self.a):
            for item in reversed(bucket):
                yield item

    def __eq__(self, other):
        """Compare element-wise with any iterable.

        Returns NotImplemented for non-iterable operands so that
        ``multiset == 5`` evaluates to False instead of raising TypeError.
        """
        try:
            other_items = list(other)
        except TypeError:
            return NotImplemented
        return list(self) == other_items

    def __len__(self) -> int:
        """Return the number of stored elements."""
        return self.size

    def __repr__(self) -> str:
        return "SortedMultiset" + str(self.a)

    def __str__(self) -> str:
        s = str(list(self))
        # Swap the surrounding list brackets for braces.
        return "{" + s[1 : len(s) - 1] + "}"

    def _position(self, x):
        """Return the bucket, index of the bucket and position in which x should be.

        self must not be empty.  If x is larger than every element, the last
        bucket is returned together with a position equal to its length.
        """
        for i, a in enumerate(self.a):
            if x <= a[-1]:
                break
        return (a, i, bisect_left(a, x))

    def __contains__(self, x) -> bool:
        """Return True if at least one element equal to x is stored."""
        if self.size == 0:
            return False
        a, _, i = self._position(x)
        return i != len(a) and a[i] == x

    def count(self, x) -> int:
        """Count the number of x."""
        return self.index_right(x) - self.index(x)

    def add(self, x) -> None:
        """Add an element. / O(√N)"""
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return
        a, b, i = self._position(x)
        a.insert(i, x)
        self.size += 1
        # Split a bucket in half once it is too long relative to the bucket count.
        if len(a) > len(self.a) * self.SPLIT_RATIO:
            mid = len(a) >> 1
            self.a[b:b + 1] = [a[:mid], a[mid:]]

    def _pop(self, a, b: int, i: int):
        """Remove and return a[i], where a is the bucket at index b of self.a."""
        ans = a.pop(i)
        self.size -= 1
        # Buckets are never left empty; other methods read a[0] / a[-1].
        if not a:
            del self.a[b]
        return ans

    def discard(self, x) -> bool:
        """Remove an element and return True if removed. / O(√N)"""
        if self.size == 0:
            return False
        a, b, i = self._position(x)
        if i == len(a) or a[i] != x:
            return False
        self._pop(a, b, i)
        return True

    def lt(self, x):
        """Find the largest element < x, or None if it doesn't exist."""
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]
        return None

    def le(self, x):
        """Find the largest element <= x, or None if it doesn't exist."""
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]
        return None

    def gt(self, x):
        """Find the smallest element > x, or None if it doesn't exist."""
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]
        return None

    def ge(self, x):
        """Find the smallest element >= x, or None if it doesn't exist."""
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
        return None

    def __getitem__(self, i: int):
        """Return the i-th element; negative i counts from the end.

        Raises IndexError when i is out of range.
        """
        if i < 0:
            for a in reversed(self.a):
                i += len(a)
                if i >= 0:
                    return a[i]
        else:
            for a in self.a:
                if i < len(a):
                    return a[i]
                i -= len(a)
        raise IndexError

    def pop(self, i: int = -1):
        """Pop and return the i-th element (the last one by default).

        Raises IndexError when i is out of range.
        """
        if i < 0:
            for b, a in enumerate(reversed(self.a)):
                i += len(a)
                if i >= 0:
                    # ~b is the negative index of this bucket within self.a.
                    return self._pop(a, ~b, i)
        else:
            for b, a in enumerate(self.a):
                if i < len(a):
                    return self._pop(a, b, i)
                i -= len(a)
        raise IndexError

    def index(self, x) -> int:
        """Count the number of elements < x."""
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x) -> int:
        """Count the number of elements <= x."""
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans

class minset:
    """Counting multiset of small non-negative integers with a two-level scan.

    ``layer0`` holds one counter per value and ``layer1`` one counter per
    group of 1024 consecutive values (index ``value >> 10``).  ``add`` and
    ``dis`` silently ignore any value that is not below the module-level
    bound ``N``.
    """

    def __init__(self):
        self.layer1 = [0] * (1 << 10)
        self.layer0 = [0] * (1 << 20)

    def add(self, x):
        """Record one more occurrence of x (no-op when x >= N)."""
        if x >= N:
            return
        self.layer0[x] += 1
        self.layer1[x >> 10] += 1

    def dis(self, x):
        """Drop one occurrence of x (no-op when x >= N); counts are not checked."""
        if x >= N:
            return
        self.layer0[x] -= 1
        self.layer1[x >> 10] -= 1

    def min(self):
        """Return the smallest value whose counter is non-zero.

        When every group counter is zero the scan runs past the end of
        ``layer1`` and raises IndexError.
        """
        group_counts = self.layer1
        group = 0
        # Find the first group that holds anything at all.
        while group_counts[group] == 0:
            group += 1
        value_counts = self.layer0
        candidate = group << 10
        # Then walk forward from the start of that group.
        while value_counts[candidate] == 0:
            candidate += 1
        return candidate

# Bound used by minset: XOR values at or above N are ignored by ms.add / ms.dis.
N = 1 << 20
ms = minset()
n, q = map(int, input().split())
A = list(map(int, input().split()))
A_idx = [i for i in range(n)]  # not referenced anywhere else in the visible code

# Split the queries by their first token.
#   Q1: (0-based position, new value) for every type-1 query; the new value is
#       also appended to A so that it takes part in the global sort below.
#   Q2: (number of type-1 queries seen so far, second token, output slot).
q1 = q2 = 0; Q1 = []; Q2 = []
for i in range(q):
    Q_tmp = tuple(map(int, input().split()))
    if Q_tmp[0] == 1:
        Q1.append((Q_tmp[1]-1, Q_tmp[2]))
        A.append(Q_tmp[2])
        q1 += 1
    else:
        Q2.append((q1, Q_tmp[1], q2))
        q2 += 1


# Batch size for grouping type-1 queries.  Without any type-2 query the
# original formula divides by zero, so fall back to a single batch.
if q2:
    q1_ = int(((n*q1)/q2)**0.5)
else:
    q1_ = q1
q1_ = q1_ if q1_ else 1

Ans = [-1] * q2
# Two sentinels that sort to both ends of B.
# NOTE(review): assumes every input value is >= 0 and < 1<<22 -- confirm.
A.extend([-1, 1<<22])
# B: indices into A ordered by value; B_val: the values in that order.
B = [i for i in range(n+q1+2)]
B.sort(key = lambda x: A[x])
B_val = [A[B[i]] for i in range(n+q1+2)]
# Replace the low sentinel's value; presumably so that its XOR with any real
# value is >= N and therefore ignored by ms -- verify against the value range.
B_val[0] = 1 << 21
# B_idx: index into A -> sorted position (entries for the sentinels stay -1).
B_idx = [-1] * (n+q1+2)
for i in range(1, n+q1+1):
    B_idx[B[i]] = i

# Sorted positions currently linked in: the n initial elements and both sentinels.
B_use = SortedMultiset([i for i in range(n+q1+2) if B[i]<n or B[i] >= n+q1])
# Queries are later popped from the end of Q2, i.e. by increasing batch
# number and, inside a batch, by decreasing second field.
Q2.sort(key = lambda x: (-(x[0]//q1_), x[1]))
# Doubly linked list over sorted positions, stored flat:
# LR[2*i] is the predecessor of i and LR[2*i+1] its successor.
LR = [1<<30] * (n+q1+2)*2
p = 0
for i in range(1, n+q1+1):
    if B[i] < n:
        # Link the initial elements in sorted order and register the XOR of
        # each neighbouring pair.
        ms.add(B_val[p]^B_val[i])
        LR[2*i] = p
        LR[2*p+1] = i
        p = i
# Close the chain with the high sentinel (sorted position n+q1+1).
LR[2*p+1] = n+q1+1
LR[-2] = p
ms.add(B_val[p]^B_val[-1])

# From here on A[i] is the sorted position of the value held at array index i.
A = [B_idx[i] for i in range(n)]
# Flat undo logs for delete(); each entry occupies 7 consecutive slots.
log0 = [-1] * 7*(n+q1)
log1 = [-1] * 7*(n+q1)
log0_cnt = log1_cnt = 0
# used[i] stores the output slot of the query that last claimed array index i.
used = [-1] * n

def delete(pl, bi, pr, flag, log_cnt):
    """Unlink sorted position bi from between pl and pr in LR and update ms.

    The XORs of the two adjacent pairs (pl, bi) and (bi, pr) are removed from
    ms and the XOR of the new adjacent pair (pl, pr) is added.

    flag selects the undo log: 0 writes to log0, 1 writes to log1, and any
    other value records nothing.  A record uses the 7 slots starting at
    log_cnt: the two LR slots being overwritten with their previous contents,
    followed by the three XOR values.
    """
    left_val = B_val[pl]
    mid_val = B_val[bi]
    right_val = B_val[pr]
    succ_slot = pl*2+1   # LR slot holding the successor of pl
    pred_slot = pr*2     # LR slot holding the predecessor of pr
    xor_left = left_val ^ mid_val
    xor_right = mid_val ^ right_val
    xor_bridge = left_val ^ right_val

    if flag == 0:
        journal = log0
    elif flag == 1:
        journal = log1
    else:
        journal = None

    if journal is not None:
        journal[log_cnt] = succ_slot
        journal[log_cnt+1] = LR[succ_slot]
        journal[log_cnt+2] = pred_slot
        journal[log_cnt+3] = LR[pred_slot]
        journal[log_cnt+4] = xor_left
        journal[log_cnt+5] = xor_right
        journal[log_cnt+6] = xor_bridge

    # Bridge pl and pr directly, skipping bi.
    LR[succ_slot] = pr
    LR[pred_slot] = pl
    ms.dis(xor_left)
    ms.dis(xor_right)
    ms.add(xor_bridge)

# Process the type-1 queries in batches [q1l, q1r) of at most q1_ updates.
# The condition is `<=` so that a final, update-free batch still runs when q1
# is a multiple of q1_ (including q1 == 0); otherwise type-2 queries issued
# after the last update would never be answered and would print -1.
q1l = 0
while q1l <= q1:
    ni = n
    # The last batch also covers queries issued after every update (count == q1).
    q1r = min(q1l+q1_, q1+1)

    # Phase 1: link the value of every update of this batch into the list,
    # between its nearest neighbours among the positions tracked by B_use.
    for qi in range(q1l, min(q1r, q1)):
        bi = B_idx[n+qi]
        idx = B_use.index(bi)
        pl = B_use[idx-1]
        pr = B_use[idx]
        LR[pl*2+1], LR[pr*2] = bi, bi
        LR[bi*2], LR[bi*2+1] = pl, pr
        ms.add(B_val[pl]^B_val[bi])
        ms.add(B_val[bi]^B_val[pr])
        ms.dis(B_val[pl]^B_val[pr])
        B_use.add(bi)

    # Phase 2: answer the type-2 queries that belong to this batch.
    while Q2 and Q2[-1][0] < q1r:
        q1_now, ni_now, ans_idx = Q2.pop()
        # Unlink array indices ni-1 .. ni_now.  These removals are logged in
        # log0 and stay in place for the remaining queries of the batch.
        for i in range(ni-1, ni_now-1, -1):
            bi = A[i]
            pl, pr = LR[bi*2], LR[bi*2+1]
            delete(pl, bi, pr, 0, log0_cnt)
            log0_cnt += 7
        ni = ni_now
        # Walk the batch's updates from newest to oldest and remove exactly
        # one node per update (logged in log1).
        for qi in range(min(q1r, q1)-1, q1l-1, -1):
            i, x = Q1[qi]
            if qi < q1_now and i < ni and used[i] != ans_idx:
                # First update met for index i that precedes the query and
                # targets an index below ni: drop the value index i held at
                # the start of the batch.
                used[i] = ans_idx
                bi = A[i]
            else:
                # Otherwise drop the value introduced by this update itself.
                bi = B_idx[n+qi]
            pl, pr = LR[bi*2], LR[bi*2+1]
            delete(pl, bi, pr, 1, log1_cnt)
            log1_cnt += 7
        Ans[ans_idx] = ms.min()
        # Undo the per-query removals in reverse order.
        while log1_cnt:
            log1_cnt -= 7
            LR[log1[log1_cnt]] = log1[log1_cnt+1]
            LR[log1[log1_cnt+2]] = log1[log1_cnt+3]
            ms.add(log1[log1_cnt+4])
            ms.add(log1[log1_cnt+5])
            ms.dis(log1[log1_cnt+6])

    # Undo the removals logged in log0 during this batch.
    while log0_cnt:
        log0_cnt -= 7
        LR[log0[log0_cnt]] = log0[log0_cnt+1]
        LR[log0[log0_cnt+2]] = log0[log0_cnt+3]
        ms.add(log0[log0_cnt+4])
        ms.add(log0[log0_cnt+5])
        ms.dis(log0[log0_cnt+6])

    # Phase 3: commit the batch.  For each update, unlink the value currently
    # recorded for its index (flag 2: nothing is logged) and record the
    # update's value, which phase 1 already linked in, as the new one.
    for qi in range(q1l, min(q1r, q1)):
        i, x = Q1[qi]
        bi = A[i]
        pl, pr = LR[bi*2], LR[bi*2+1]
        delete(pl, bi, pr, 2, 0)
        B_use.discard(bi)
        A[i] = B_idx[n+qi]
    q1l = q1r

print(*Ans, sep="\n")
0