結果

問題 No.3239 Omnibus
ユーザー titia
提出日時 2025-08-31 01:27:35
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 664 ms / 10,000 ms
コード長 4,447 bytes
コンパイル時間 372 ms
コンパイル使用メモリ 82,244 KB
実行使用メモリ 122,724 KB
最終ジャッジ日時 2025-08-31 01:27:59
合計ジャッジ時間 16,980 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 33
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
input = sys.stdin.readline

# class化
class Bit_indexed_tree():
    def __init__(self, LEN):
        self.BIT = [0]*(LEN+1) # 1-indexedなtree. 配列BITの長さはLEN+1にしていることに注意。
        self.LEN = LEN

    def update(self,v,w): # index vにwを加える
        while v<=self.LEN:
            self.BIT[v]+=w
            v+=(v&(-v)) # v&(-v)で、最も下の立っているビット. 自分を含む大きなノードへ. たとえばv=3→v=4

    def getvalue(self,v): # [1,v]の区間の和を求める
        ANS=0
        while v!=0:
            ANS+=self.BIT[v]
            v-=(v&(-v)) # 自分より小さい自分の和を構成するノードへ. たとえばv=14→v=12へ
        return ANS

    def bisect_on_BIT(self,x): # [1,ind]の和がはじめてx以上になるindexを探す

        if x<=0:
            return 0
        
        ANS=0
        h=1<<((self.LEN).bit_length()-1) # LEN以下の最小の2ベキ
        while h>0:
            if ANS+h<=self.LEN and self.BIT[ANS+h]<x:
                x-=self.BIT[ANS+h]
                ANS+=h
            h//=2

        return ANS+1 # LENまでの和がx未満のとき, LEN+1を返すことに注意

N,Q=map(int,input().split())
S=list(input().strip())

LIST=[set() for i in range(26*26*26)]
LIST2=[[] for i in range(26*26*26)]
LIST3=[[] for i in range(26*26*26)]

def tm(i):
    x=ord(S[i])-97
    y=ord(S[i+1])-97
    z=ord(S[i+2])-97

    a=x+y*26+z*26*26

    return a

def tm2(s,i):
    x=ord(s[i])-97
    y=ord(s[i+1])-97
    z=ord(s[i+2])-97

    a=x+y*26+z*26*26

    return a

for i in range(N-2):
    a=tm(i)
    LIST[a].add(i)
const0=10000
for i in range(26*26*26):
    if len(LIST[i])>=const0:
        LIST2[i]=Bit_indexed_tree(N+100)
        LIST3[i]=Bit_indexed_tree(N+100)

        for x in LIST[i]:
            LIST2[i].update(x+1,x)
            LIST3[i].update(x+1,1)
            


for tests in range(Q):
    L=input().split()
    if len(L)==3:
        i,x=L[1],L[2]
        i=int(i)-1
        
        if i-2>=0:
            a=tm(i-2)
            if LIST2[a]:
                LIST2[a].update(i-2+1,-(i-2))
                LIST3[a].update(i-2+1,-1)
            else:
                LIST[a].remove(i-2)
        if i-1>=0 and i+1<len(S):
            a=tm(i-1)
            if LIST2[a]:
                LIST2[a].update(i-1+1,-(i-1))
                LIST3[a].update(i-1+1,-1)
            else:
                LIST[a].remove(i-1)
        if i+2<len(S):
            a=tm(i)
            if LIST2[a]:
                LIST2[a].update(i+1,-(i))
                LIST3[a].update(i+1,-1)
            else:
                LIST[a].remove(i)

        S[i]=x

        if i-2>=0:
            a=tm(i-2)
            if LIST2[a]:
                LIST2[a].update(i-2+1,i-2)
                LIST3[a].update(i-2+1,1)
            else:
                LIST[a].add(i-2)

                if len(LIST[a])>=const0:
                    LIST2[a]=Bit_indexed_tree(N+100)

                    for x in LIST[a]:
                        LIST2[a].update(x+1,x)
                        LIST3[a].update(x+1,1)
        if i-1>=0 and i+1<len(S):
            a=tm(i-1)
            if LIST2[a]:
                LIST2[a].update(i-1+1,i-1)
                LIST3[a].update(i-1+1,1)
            else:
                LIST[a].add(i-1)
                if len(LIST[a])>=const0:
                    LIST2[a]=Bit_indexed_tree(N+100)

                    for x in LIST[a]:
                        LIST2[a].update(x+1,x)
                        LIST3[a].update(x+1,1)
        if i+2<len(S):
            a=tm(i)
            if LIST2[a]:
                LIST2[a].update(i+1,i)
                LIST3[a].update(i+1,1)
            else:
                LIST[a].add(i)
                if len(LIST[a])>=const0:
                    LIST2[a]=Bit_indexed_tree(N+100)

                    for x in LIST[a]:
                        LIST2[a].update(x+1,x)
                        LIST3[a].update(x+1,1)

    else:
        l,r,s=L[1],L[2],L[3]

        l=int(l)-1
        r=int(r)-1

        a=tm2(s,0)

        ANS=0

        if LIST2[a]:
            cc=LIST2[a].getvalue(r-2+1)-LIST2[a].getvalue(l)
            dd=LIST3[a].getvalue(r-2+1)-LIST3[a].getvalue(l)

            ANS=cc-(l-1)*dd

            print(ANS)

        else:
            for x in LIST[a]:
                if l<=x<=r-2:
                    ANS+=x-l+1

            print(ANS)

        
        

        
        



0