結果
| 問題 |
No.876 Range Compress Query
|
| コンテスト | |
| ユーザー |
👑 SPD_9X2
|
| 提出日時 | 2020-09-03 15:29:38 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
AC
|
| 実行時間 | 461 ms / 2,000 ms |
| コード長 | 3,081 bytes |
| コンパイル時間 | 137 ms |
| コンパイル使用メモリ | 82,368 KB |
| 実行使用メモリ | 93,824 KB |
| 最終ジャッジ日時 | 2024-11-23 18:53:34 |
| 合計ジャッジ時間 | 5,322 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 18 |
ソースコード
"""
https://yukicoder.me/problems/no/876
種類数(分断されている場合別カウント)を計算 + 区間更新を処理
xを足して、同じ、違うになる可能性があるのは端のみ
区間内の句切れ目の数+1が答え
和のセグ木を保存しておけばおk
"""
#0-indexed , 半開区間[a,b)
#calc変更で演算変更
class SegTree:
def __init__(self,N,first):
self.NO = 2**(N-1).bit_length()
self.First = first
self.data = [first] * (2*self.NO)
def calc(self,l,r):
return l+r
def update(self,ind,x):
ind += self.NO - 1
self.data[ind] = x
while ind >= 0:
ind = (ind - 1)//2
self.data[ind] = self.calc(self.data[2*ind+1],self.data[2*ind+2])
def query(self,l,r):
L = l + self.NO
R = r + self.NO
s = self.First
while L < R:
if R & 1:
R -= 1
s = self.calc(s , self.data[R-1])
if L & 1:
s = self.calc(s , self.data[L-1])
L += 1
L >>= 1
R >>= 1
return s
class RangeBIT:
def __init__(self,N,indexed):
self.bit1 = [0] * (N+2)
self.bit2 = [0] * (N+2)
self.mode = indexed
def bitadd(self,a,w,bit): #aにwを加える(1-origin)
x = a
while x <= (len(bit)-1):
bit[x] += w
x += x & (-1 * x)
def bitsum(self,a,bit): #ind 1~aまでの和を求める
ret = 0
x = a
while x > 0:
ret += bit[x]
x -= x & (-1 * x)
return ret
def add(self,l,r,w): #半開区間[l,r)にwを加える
l = l + (1-self.mode)
r = r + (1-self.mode)
self.bitadd(l,-1*w*l,self.bit1)
self.bitadd(r,w*r,self.bit1)
self.bitadd(l,w,self.bit2)
self.bitadd(r,-1*w,self.bit2)
def sum(self,l,r): #半開区間[l,r)の区間和
l = l + (1-self.mode)
r = r + (1-self.mode)
#print ("s",l,r)
ret = self.bitsum(r,self.bit1) + r * self.bitsum(r,self.bit2)
ret -= self.bitsum(l,self.bit1) + l * self.bitsum(l,self.bit2)
return ret
from sys import stdin
N,Q = map(int,stdin.readline().split())
a = list(map(int,stdin.readline().split()))
ST = RangeBIT(N,0)
RT = SegTree(N,0)
for i in range(N):
ST.add(i,i+1,a[i])
#print (RT.data)
for i in range(N-1):
if a[i] != a[i+1]:
RT.update(i,1)
#print (RT.data)
for loop in range(Q):
q = stdin.readline()
if q[0] == "1":
s,l,r,x = map(int,q.split())
l -= 1 ; r -= 1
if l - 1 >= 0:
RT.update(l-1,0)
RT.update(r,0)
ST.add(l,r+1,x)
if l-1 >= 0 and ST.sum(l-1,l) != ST.sum(l,l+1):
RT.update(l-1,1)
if r+1 < N and ST.sum(r,r+1) != ST.sum(r+1,r+2):
RT.update(r,1)
else:
s,l,r = map(int,q.split())
l -= 1 ; r -= 1
if l == r:
print (1)
else:
print (RT.query(l,r)+1)
SPD_9X2