from collections import Counter


# SA-IS suffix array construction.
def sais(s):
    """Return the suffix array of ``s``, including the empty suffix.

    Each distinct element of ``s`` is replaced by its 1-based rank in sorted
    order, so ``s`` may be a ``str`` or any sequence of hashable, mutually
    comparable items.

    Returns a list of ``len(s) + 1`` start positions in increasing suffix
    order; the first entry is always ``len(s)`` (the empty suffix).
    """
    alphabet = sorted(set(s))
    rank = {ch: i + 1 for i, ch in enumerate(alphabet)}
    return sais_rec([rank[ch] for ch in s], len(alphabet))


def _induced_sort(lst, t, lms_order, cstart, cend):
    """Run one SA-IS induced-sorting pass and return the filled array.

    ``lms_order`` yields LMS positions in the order they are inserted; each is
    written at the current tail of its bucket, so within a bucket the last
    one inserted ends up first.
    """
    res = [-1] * len(lst)
    # Seed: drop LMS positions into the tails of their buckets.
    tails = [iter(range(e - 1, s - 1, -1)) for s, e in zip(cstart, cend)]
    for p in lms_order:
        res[next(tails[lst[p]])] = p
    # Induce L-type suffixes left to right. `res` is extended while it is
    # being scanned; the list iterator deliberately sees those new entries.
    heads = [iter(range(s, e)) for s, e in zip(cstart, cend)]
    for p in res:
        if p > 0 and not t[p - 1]:
            res[next(heads[lst[p - 1]])] = p - 1
    # Induce S-type suffixes right to left with fresh bucket tails; this
    # overwrites the seeded LMS slots with their final positions.
    tails = [iter(range(e - 1, s - 1, -1)) for s, e in zip(cstart, cend)]
    for p in reversed(res):
        if p > 0 and t[p - 1]:
            res[next(tails[lst[p - 1]])] = p - 1
    return res


def sais_rec(lst, num):
    """Recursive SA-IS core: suffix array of ``lst`` plus an appended sentinel.

    Args:
        lst: list of ints, expected to lie in ``1..num`` (the callers in this
            file guarantee this). A sentinel ``0`` is appended internally, so
            it must be strictly smaller than every element.
        num: largest symbol value that may occur in ``lst``.

    Returns:
        List of ``len(lst) + 1`` suffix start positions of ``lst + [0]`` in
        sorted order; ``result[0] == len(lst)``.
    """
    n = len(lst)
    if n < 2:
        # With at most one real symbol the order is fixed: sentinel first,
        # then the lone symbol (if any). Return positions, not symbol values.
        return list(range(n, -1, -1))
    lst = lst + [0]  # copy with sentinel; the caller's list is not mutated
    L = n + 1

    # t[i] is 1 when suffix i is S-type (smaller than suffix i+1), else 0
    # (L-type). The sentinel suffix is S-type.
    t = [1] * L
    for i in range(L - 2, -1, -1):
        t[i] = 1 if (lst[i] < lst[i + 1] or (lst[i] == lst[i + 1] and t[i + 1])) else 0
    # LMS position: S-type with an L-type left neighbour. Index 0 is never
    # LMS because t[-1] (the sentinel's type) is 1.
    isLMS = [t[i - 1] < t[i] for i in range(L)]
    LMS = [i for i in range(1, L) if isLMS[i]]
    LMSn = len(LMS)

    # Bucket boundaries: suffixes starting with symbol c occupy
    # positions cstart[c] .. cend[c]-1 of the suffix array.
    count = Counter(lst)
    cstart = [0] * (num + 1)
    cend = [0] * (num + 1)
    total = 0
    for c in range(num + 1):
        cstart[c] = total
        total += count[c]
        cend[c] = total

    # Pass 1: seed with LMS positions in descending text order; this sorts
    # the LMS substrings (not yet the LMS suffixes).
    res = _induced_sort(lst, t, reversed(LMS), cstart, cend)

    # Name LMS substrings in sorted order. Neighbours are compared symbol by
    # symbol up to the next LMS position; a mismatch starts a new name.
    name = 0
    prev = -1
    pLMS = {}
    for e in res:
        if not isLMS[e]:
            continue
        if prev == -1 or lst[e] != lst[prev]:
            name += 1
            prev = e
        else:
            for i in range(1, L):
                if lst[e + i] != lst[prev + i]:
                    name += 1
                    prev = e
                    break
                if isLMS[e + i] or isLMS[prev + i]:
                    break
        pLMS[e] = name - 1

    if name < LMSn:
        # Some names repeat: order the LMS suffixes by recursing on the
        # string of names (the sentinel LMS L-1 is excluded; the recursive
        # call appends its own sentinel, which maps back to LMS[-1]).
        sublst = [pLMS[e] for e in LMS if e < L - 1]
        ret = sais_rec(sublst, name - 1)
        LMS = [LMS[k] for k in reversed(ret)]
    else:
        # All names distinct: pass 1 already ordered the LMS suffixes.
        LMS = [e for e in reversed(res) if isLMS[e]]

    # Pass 2: induce the final order from the sorted LMS suffixes, inserted
    # largest first because buckets are filled from their tails.
    return _induced_sort(lst, t, LMS, cstart, cend)


# Longest Common Prefix
# Takes (string s, string length n, suffix array) as arguments.
# NOTE(review): this file was collapsed onto a few long lines, so the LCP
# header below was already comment text; it continues on the next line.
# def LCP(s, n, sa): lcp = 
[-1]*(n+1) rank = [0]*(n+1) for i in range(n+1): rank[sa[i]] = i h = 0 lcp[0] = 0 for i in range(n): j = sa[rank[i] - 1] if h > 0: h -= 1 while j+h < n and i+h < n and s[j+h]==s[i+h]: h += 1 lcp[rank[i] - 1] = h return lcp #https://github.com/shakayami/ACL-for-python/blob/master/convolution.py class FFT(): def primitive_root_constexpr(self,m): if m==2:return 1 if m==167772161:return 3 if m==469762049:return 3 if m==754974721:return 11 if m==998244353:return 3 divs=[0]*20 divs[0]=2 cnt=1 x=(m-1)//2 while(x%2==0):x//=2 i=3 while(i*i<=x): if (x%i==0): divs[cnt]=i cnt+=1 while(x%i==0): x//=i i+=2 if x>1: divs[cnt]=x cnt+=1 g=2 while(1): ok=True for i in range(cnt): if pow(g,(m-1)//divs[i],m)==1: ok=False break if ok: return g g+=1 def bsf(self,x): res=0 while(x%2==0): res+=1 x//=2 return res rank2=0 root=[] iroot=[] rate2=[] irate2=[] rate3=[] irate3=[] def __init__(self,MOD): self.mod=MOD self.g=self.primitive_root_constexpr(self.mod) self.rank2=self.bsf(self.mod-1) self.root=[0 for i in range(self.rank2+1)] self.iroot=[0 for i in range(self.rank2+1)] self.rate2=[0 for i in range(self.rank2)] self.irate2=[0 for i in range(self.rank2)] self.rate3=[0 for i in range(self.rank2-1)] self.irate3=[0 for i in range(self.rank2-1)] self.root[self.rank2]=pow(self.g,(self.mod-1)>>self.rank2,self.mod) self.iroot[self.rank2]=pow(self.root[self.rank2],self.mod-2,self.mod) for i in range(self.rank2-1,-1,-1): self.root[i]=(self.root[i+1]**2)%self.mod self.iroot[i]=(self.iroot[i+1]**2)%self.mod prod=1;iprod=1 for i in range(self.rank2-1): self.rate2[i]=(self.root[i+2]*prod)%self.mod self.irate2[i]=(self.iroot[i+2]*iprod)%self.mod prod=(prod*self.iroot[i+2])%self.mod iprod=(iprod*self.root[i+2])%self.mod prod=1;iprod=1 for i in range(self.rank2-2): self.rate3[i]=(self.root[i+3]*prod)%self.mod self.irate3[i]=(self.iroot[i+3]*iprod)%self.mod prod=(prod*self.iroot[i+3])%self.mod iprod=(iprod*self.root[i+3])%self.mod def butterfly(self,a): n=len(a) h=(n-1).bit_length() LEN=0 while(LEN1: 
mid=(right+left)//2 if f(mid): left=mid else: right=mid return right def isok(n): pos=sa[n] if S[pos:].upper()=10**8: tmp-=998244353 if M>tmp>=M-K*2: ans+=1 print(ans)