結果
| 問題 | 
                            No.1863 Xor Sum 2...?
                             | 
                    
| コンテスト | |
| ユーザー | 
                            👑  SPD_9X2
                         | 
                    
| 提出日時 | 2022-03-04 22:19:47 | 
| 言語 | PyPy3  (7.3.15)  | 
                    
| 結果 | 
                             
                                AC
                                 
                             
                            
                         | 
                    
| 実行時間 | 138 ms / 2,000 ms | 
| コード長 | 1,879 bytes | 
| コンパイル時間 | 226 ms | 
| コンパイル使用メモリ | 82,176 KB | 
| 実行使用メモリ | 100,992 KB | 
| 最終ジャッジ日時 | 2024-07-18 20:49:35 | 
| 合計ジャッジ時間 | 4,527 ms | 
| 
                            ジャッジサーバーID (参考情報)  | 
                        judge2 / judge3 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 3 | 
| other | AC * 28 | 
ソースコード
"""
1863:
右辺は0 or 1
0の場合、
Aで許されているのは、
各bitが2回以上登場しない場合である。
→ 尺取り法で行ける
1の場合は
2のbitが2回出ると
損失が2になるのでまずい
というか、存在しない
(Lnex,i]の区間で、XORが0になる区間の個数を求める
"""
import sys
from sys import stdin
from collections import deque
#0-indexed , 半開区間[a,b)
#calc変更で演算変更
class SegTree:
    def __init__(self,N,first):
        self.NO = 2**(N-1).bit_length()
        self.First = first
        self.data = [first] * (2*self.NO)
    def calc(self,l,r):
        return l+r
    def update(self,ind,x):
        ind += self.NO - 1
        self.data[ind] = x
        while ind >= 0:
            ind = (ind - 1)//2
            self.data[ind] = self.calc(self.data[2*ind+1],self.data[2*ind+2])
    def query(self,l,r):
        L = l + self.NO
        R = r + self.NO
        s = self.First
        while L < R:
            if R & 1:
                R -= 1
                s = self.calc(s , self.data[R-1])
            if L & 1:
                s = self.calc(s , self.data[L-1])
                L += 1
            L >>= 1
            R >>= 1
        return s
    def get(self , ind):
        ind += self.NO - 1
        return self.data[ind]
N = int(stdin.readline())
A = list(map(int,stdin.readline().split()))
B = [0] + list(map(int,stdin.readline().split()))
for i in range(N):
    B[i + 1] ^= B[i]
ST = SegTree(N+1,0)
for i in range(N+1):
    ST.update(i,B[i])
Lnex = 0
S = 0
X = 0
ans = 0
#print (B)
for i in range(1,N+1):
    S += A[i-1]
    X ^= A[i-1]
    while S != X:
        S -= A[Lnex]
        X ^= A[Lnex]
        Lnex += 1
    
    nsum = ST.query(Lnex,i)
    if ST.get(i) == 0:
        nsum = ( i-Lnex ) - nsum
    #print (Lnex,i,nsum)
    ans += nsum
print (ans)
            
            
            
        
            
SPD_9X2