結果
問題 | No.2361 Many String Compare Queries |
ユーザー | とりゐ |
提出日時 | 2023-01-14 18:10:36 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 1,424 ms / 2,500 ms |
コード長 | 7,057 bytes |
コンパイル時間 | 289 ms |
コンパイル使用メモリ | 82,728 KB |
実行使用メモリ | 175,364 KB |
最終ジャッジ日時 | 2024-06-30 20:40:43 |
合計ジャッジ時間 | 10,466 ms |
ジャッジサーバーID (参考情報) | judge4 / judge2 |
(要ログイン)
テストケース
入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 40 ms | 55,188 KB |
testcase_01 | AC | 40 ms | 55,116 KB |
testcase_02 | AC | 39 ms | 55,552 KB |
testcase_03 | AC | 43 ms | 55,896 KB |
testcase_04 | AC | 41 ms | 56,464 KB |
testcase_05 | AC | 40 ms | 54,816 KB |
testcase_06 | AC | 41 ms | 55,228 KB |
testcase_07 | AC | 63 ms | 70,292 KB |
testcase_08 | AC | 1,385 ms | 161,252 KB |
testcase_09 | AC | 1,390 ms | 160,272 KB |
testcase_10 | AC | 1,424 ms | 161,096 KB |
testcase_11 | AC | 847 ms | 146,884 KB |
testcase_12 | AC | 982 ms | 154,704 KB |
testcase_13 | AC | 947 ms | 152,968 KB |
testcase_14 | AC | 851 ms | 172,040 KB |
testcase_15 | AC | 929 ms | 175,364 KB |
ソースコード
from sys import stdin

# Line reader that drops the final character (the trailing newline).
input = lambda: stdin.readline()[:-1]


class segtree():
    """Iterative bottom-up segment tree.

    Args:
        init: initial leaf values; the tree covers indices ``0..len(init)-1``.
        func: associative binary operation used to combine two values.
        ide: identity element of ``func``; fills padding leaves and is the
            result of an empty query.
    """

    def __init__(self, init, func, ide):
        self.n = len(init)
        self.func = func
        self.ide = ide
        # Leaf count rounded up to a power of two.
        self.size = 1 << (self.n - 1).bit_length()
        self.tree = [self.ide for i in range(2 * self.size)]
        for i in range(self.n):
            self.tree[self.size + i] = init[i]
        for i in range(self.size - 1, 0, -1):
            self.tree[i] = self.func(self.tree[2 * i], self.tree[2 * i | 1])

    def update(self, k, x):
        """Set leaf ``k`` to ``x`` and recompute all of its ancestors."""
        k += self.size
        self.tree[k] = x
        k >>= 1
        while k:
            self.tree[k] = self.func(self.tree[2 * k], self.tree[k * 2 | 1])
            k >>= 1

    def get(self, i):
        """Return the current value stored at leaf ``i``."""
        return self.tree[i + self.size]

    def add(self, k, x):
        """Add ``x`` to leaf ``k`` using ``+`` (not ``func``)."""
        x += self.get(k)
        self.update(k, x)

    def query(self, l, r):
        """Fold ``func`` over leaves in the half-open range ``[l, r)``.

        Returns ``ide`` when the range is empty.
        """
        l += self.size
        r += self.size
        l_res = self.ide
        r_res = self.ide
        while l < r:
            if l & 1:
                l_res = self.func(l_res, self.tree[l])
                l += 1
            if r & 1:
                r -= 1
                r_res = self.func(self.tree[r], r_res)
            l >>= 1
            r >>= 1
        return self.func(l_res, r_res)

    def debug(self, s=10):
        """Print the first ``min(n, s)`` leaf values."""
        print([self.get(i) for i in range(min(self.n, s))])


class string:
    """Suffix-array, LCP-array and Z-function helpers.

    The API mirrors AtCoder Library's ``string`` module (``suffix_array``,
    ``lcp_array``, ``z_algorithm``). Every member is static and is called
    as ``string.xxx(...)``.
    """

    @staticmethod
    def sa_is(s, upper):
        """Build the suffix array of the int list ``s`` with SA-IS.

        Args:
            s: list of ints with ``0 <= s[i] <= upper``.
            upper: largest symbol value; sizes the bucket arrays.

        Returns:
            ``sa`` such that ``sa[k]`` is the start of the k-th smallest
            suffix of ``s``.
        """
        n = len(s)
        if n == 0:
            return []
        if n == 1:
            return [0]
        if n == 2:
            if s[0] < s[1]:
                return [0, 1]
            else:
                return [1, 0]
        sa = [0] * n
        # ls[i] is truthy when suffix i is S-type (smaller than suffix i+1).
        ls = [0] * n
        for i in range(n - 2, -1, -1):
            ls[i] = ls[i + 1] if (s[i] == s[i + 1]) else (s[i] < s[i + 1])
        # Bucket offsets: sum_l[c] is where bucket c starts, sum_s[c] is
        # where the S-type part of bucket c starts.
        sum_l = [0] * (upper + 1)
        sum_s = [0] * (upper + 1)
        for i in range(n):
            if not ls[i]:
                sum_s[s[i]] += 1
            else:
                sum_l[s[i] + 1] += 1
        for i in range(upper + 1):
            sum_s[i] += sum_l[i]
            if i < upper:
                sum_l[i + 1] += sum_s[i]

        def induce(lms):
            """Induced sort seeded with the LMS positions in ``lms``."""
            for i in range(n):
                sa[i] = -1
            buf = sum_s[:]
            for d in lms:
                if d == n:
                    continue
                sa[buf[s[d]]] = d
                buf[s[d]] += 1
            # Left-to-right pass places L-type suffixes.
            buf = sum_l[:]
            sa[buf[s[n - 1]]] = n - 1
            buf[s[n - 1]] += 1
            for i in range(n):
                v = sa[i]
                if v >= 1 and not ls[v - 1]:
                    sa[buf[s[v - 1]]] = v - 1
                    buf[s[v - 1]] += 1
            # Right-to-left pass places S-type suffixes.
            buf = sum_l[:]
            for i in range(n - 1, -1, -1):
                v = sa[i]
                if v >= 1 and ls[v - 1]:
                    buf[s[v - 1] + 1] -= 1
                    sa[buf[s[v - 1] + 1]] = v - 1

        # lms_map[i] = ordinal of position i among LMS positions, else -1.
        lms_map = [-1] * (n + 1)
        m = 0
        for i in range(1, n):
            if not ls[i - 1] and ls[i]:
                lms_map[i] = m
                m += 1
        lms = []
        for i in range(1, n):
            if not ls[i - 1] and ls[i]:
                lms.append(i)
        induce(lms)

        if m:
            # LMS positions in the order produced by the first induced sort.
            sorted_lms = []
            for v in sa:
                if lms_map[v] != -1:
                    sorted_lms.append(v)
            # Name each LMS substring; identical substrings share a name.
            rec_s = [0] * m
            rec_upper = 0
            rec_s[lms_map[sorted_lms[0]]] = 0
            for i in range(1, m):
                l = sorted_lms[i - 1]
                r = sorted_lms[i]
                end_l = lms[lms_map[l] + 1] if (lms_map[l] + 1 < m) else n
                end_r = lms[lms_map[r] + 1] if (lms_map[r] + 1 < m) else n
                same = True
                if end_l - l != end_r - r:
                    same = False
                else:
                    while l < end_l:
                        if s[l] != s[r]:
                            break
                        l += 1
                        r += 1
                    if (l == n) or (s[l] != s[r]):
                        same = False
                if not same:
                    rec_upper += 1
                rec_s[lms_map[sorted_lms[i]]] = rec_upper
            # Recurse on the reduced string to order the LMS suffixes exactly,
            # then redo the induced sort with that order.
            rec_sa = string.sa_is(rec_s, rec_upper)
            for i in range(m):
                sorted_lms[i] = lms[rec_sa[i]]
            induce(sorted_lms)
        return sa

    @staticmethod
    def suffix_array_upper(s, upper):
        """Suffix array of an int list whose values lie in ``[0, upper]``."""
        assert 0 <= upper
        for d in s:
            assert 0 <= d and d <= upper
        return string.sa_is(s, upper)

    @staticmethod
    def suffix_array(s):
        """Return the suffix array of ``s``.

        ``s`` is a ``str`` or a sequence of mutually comparable items. A
        string whose code points are all <= 255 goes straight to SA-IS with
        ``upper=255``. Any other input is coordinate-compressed first, so
        non-Latin-1 text works as well.
        """
        n = len(s)
        if isinstance(s, str):
            codes = [ord(c) for c in s]
            if not codes or max(codes) <= 255:
                return string.sa_is(codes, 255)
            # Code points above 255 would overflow the 256-wide buckets, so
            # compress them like any other sequence.
            s = codes
        idx = list(range(n))
        idx.sort(key=lambda x: s[x])
        s2 = [0] * n
        now = 0
        for i in range(n):
            # New rank only when the value changes in sorted order. Note that
            # `i & a != b` would parse as `(i & a) != b`.
            if i and s[idx[i - 1]] != s[idx[i]]:
                now += 1
            s2[idx[i]] = now
        return string.sa_is(s2, now)

    @staticmethod
    def lcp_array(s, sa):
        """Kasai's algorithm.

        ``lcp[k]`` is the length of the longest common prefix of suffixes
        ``sa[k]`` and ``sa[k+1]``. Requires ``len(s) >= 1`` and returns a
        list of length ``len(s) - 1``.
        """
        n = len(s)
        assert n >= 1
        rnk = [0] * n
        for i in range(n):
            rnk[sa[i]] = i
        lcp = [0] * (n - 1)
        h = 0
        for i in range(n):
            if h > 0:
                h -= 1
            if rnk[i] == 0:
                continue
            j = sa[rnk[i] - 1]
            while j + h < n and i + h < n:
                if s[j + h] != s[i + h]:
                    break
                h += 1
            lcp[rnk[i] - 1] = h
        return lcp

    @staticmethod
    def z_algorithm(s):
        """Z-function: ``z[i]`` is the LCP length of ``s`` and ``s[i:]``.

        ``z[0] == len(s)``; an empty input returns ``[]``.
        """
        n = len(s)
        if n == 0:
            return []
        z = [0] * n
        i = 1
        j = 0
        while i < n:
            z[i] = 0 if (j + z[j] <= i) else min(j + z[j] - i, z[i - j])
            while (i + z[i] < n) and (s[z[i]] == s[i + z[i]]):
                z[i] += 1
            if j + z[j] < i + z[i]:
                j = i
            i += 1
        z[0] = n
        return z


class SparseTable:
    """Static range-fold table for an idempotent ``func`` (e.g. ``min``).

    ``query(l, r)`` folds ``init[l:r]`` in O(1) by combining two overlapping
    power-of-two windows, so ``func`` must tolerate double counting. Row
    ``i`` lives at offset ``i << size``; its entry ``j`` covers
    ``[j, j + 2**i)``.
    """

    def __init__(self, init, func, e):
        n = len(init)
        self.e = e
        self.func = func
        size = 0
        while (1 << size) <= n:
            size += 1
        self.size = size
        self.table = [e] * (size * (1 << size))
        for i in range(n):
            self.table[i] = init[i]
        for i in range(1, size):
            for j in range((1 << size) - (1 << i) + 1):
                self.table[(i << size) + j] = func(
                    self.table[((i - 1) << size) + j],
                    self.table[((i - 1) << size) + j + (1 << (i - 1))],
                )

    def query(self, l, r):
        """Fold ``func`` over ``[l, r)``; returns ``e`` when ``l == r``."""
        if l == r:
            return self.e
        s = (r - l).bit_length() - 1
        return self.func(
            self.table[(s << self.size) + l],
            self.table[(s << self.size) + r - (1 << s)],
        )


# Read the input only when run as a script, so importing this module does
# not consume stdin. The names bound here stay module-level and are used by
# the rest of the script.
if __name__ == "__main__":
    n, q = map(int, input().split())
    S = input()
    SA = string.suffix_array(S)
    LCP = string.lcp_array(S, SA)
    # p[i] = rank of suffix i in SA (inverse suffix array).
    p = [0] * n
    for i in range(n):
        p[SA[i]] = i
# ---------------------------------------------------------------------------
# Query phase. Uses n, q, SA, LCP, p, input, segtree and SparseTable, which
# are defined earlier in the script.
# ---------------------------------------------------------------------------

# min_lcp.query(a, b) == min(LCP[a:b]): the LCP of the suffixes ranked a and b.
min_lcp = SparseTable(LCP, min, 10**9)


def _first_rank_with_common_prefix(rank, length):
    """Binary-search the smallest rank r <= ``rank`` such that the suffixes
    ranked r and ``rank`` share at least ``length`` leading characters."""
    bad, good = -1, rank
    while good - bad > 1:
        mid = (bad + good) // 2
        if min_lcp.query(mid, rank) >= length:
            good = mid
        else:
            bad = mid
    return good


# Group queries by the rank at which the sweep below answers them.
pending = [[] for _ in range(n)]
for qid in range(q):
    lo_pos, hi_pos = (int(tok) - 1 for tok in input().split())
    length = hi_pos - lo_pos + 1
    pending[_first_rank_with_common_prefix(p[lo_pos], length)].append((length, qid))

# before[r] = sum of the lengths of the suffixes ranked below r.
before = [0] * n
for r in range(1, n):
    before[r] = before[r - 1] + (n - SA[r - 1])

answers = [0] * q
weighted = segtree([0] * (n + 1), lambda a, b: a + b, 0)  # slot v: v * count
tally = segtree([0] * (n + 1), lambda a, b: a + b, 0)     # slot v: count

# Sweep ranks from high to low. `stack` holds (value, count) pairs for the
# suffixes ranked at or after the current one: their LCP with the current
# suffix, grouped by value (pairs with value 0 are dropped).
stack = []
for rank in range(n - 1, -1, -1):
    if rank != n - 1:
        cap = LCP[rank]
        carried = 0
        # Every tracked LCP that is >= cap is clamped down to cap.
        while stack and stack[-1][0] >= cap:
            value, cnt = stack.pop()
            carried += cnt
            weighted.update(value, 0)
            tally.update(value, 0)
        if carried != 0 and cap != 0:
            weighted.update(cap, cap * carried)
            tally.update(cap, carried)
            stack.append((cap, carried))
    own = n - SA[rank]  # the current suffix matches itself over its full length
    weighted.add(own, own)
    tally.add(own, 1)
    stack.append((own, 1))
    for length, qid in pending[rank]:
        # before[rank] + sum over tracked suffixes of min(value, length - 1)
        answers[qid] += (before[rank]
                         + weighted.query(0, length)
                         + tally.query(length, n + 1) * (length - 1))

print(*answers, sep='\n')