結果
問題 | No.1068 #いろいろな色 / Red and Blue and more various colors (Hard) |
ユーザー | None |
提出日時 | 2021-05-12 18:06:04 |
言語 | PyPy3 (7.3.15) |
結果 | TLE |
実行時間 | - |
コード長 | 22,125 bytes |
コンパイル時間 | 357 ms |
コンパイル使用メモリ | 82,652 KB |
実行使用メモリ | 298,364 KB |
最終ジャッジ日時 | 2024-09-23 02:50:47 |
合計ジャッジ時間 | 13,719 ms |
ジャッジサーバーID (参考情報) | judge4 / judge2 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 47 ms | 63,616 KB |
testcase_01 | AC | 47 ms | 58,240 KB |
testcase_02 | AC | 46 ms | 58,624 KB |
testcase_03 | AC | 1,037 ms | 93,924 KB |
testcase_04 | AC | 647 ms | 85,516 KB |
testcase_05 | AC | 812 ms | 88,704 KB |
testcase_06 | AC | 487 ms | 81,536 KB |
testcase_07 | AC | 409 ms | 80,428 KB |
testcase_08 | AC | 721 ms | 87,040 KB |
testcase_09 | AC | 934 ms | 91,140 KB |
testcase_10 | AC | 254 ms | 78,592 KB |
testcase_11 | AC | 384 ms | 79,988 KB |
testcase_12 | AC | 211 ms | 77,784 KB |
testcase_13 | TLE | - | - |
testcase_14 | -- | - | - |
testcase_15 | -- | - | - |
testcase_16 | -- | - | - |
testcase_17 | -- | - | - |
testcase_18 | -- | - | - |
testcase_19 | -- | - | - |
testcase_20 | -- | - | - |
testcase_21 | -- | - | - |
testcase_22 | -- | - | - |
testcase_23 | -- | - | - |
testcase_24 | -- | - | - |
testcase_25 | -- | - | - |
testcase_26 | -- | - | - |
testcase_27 | -- | - | - |
testcase_28 | -- | - | - |
testcase_29 | -- | - | - |
testcase_30 | -- | - | - |
testcase_31 | -- | - | - |
ソースコード
import sys

MOD = 998244353
# Twiddle-factor tables used by butterfly / butterfly_inv. After finishing
# block s, the running twiddle is multiplied by entry t, where t is the number
# of trailing 1-bits of s (that is what `(~s & -~s).bit_length() - 1` yields).
sum_e = (
    911660635, 509520358, 369330050, 332049552, 983190778, 123842337,
    238493703, 975955924, 603855026, 856644456, 131300601, 842657263,
    730768835, 942482514, 806263778, 151565301, 510815449, 503497456,
    743006876, 741047443, 56250497,
)
sum_ie = (
    86583718, 372528824, 373294451, 645684063, 112220581, 692852209,
    155456985, 797128860, 90816748, 860285882, 927414960, 354738543,
    109331171, 293255632, 535113200, 308540755, 121186627, 608385704,
    438932459, 359477183, 824071951,
)


def mod_sqrt(a):
    """Return some x with x*x == a (mod MOD), or -1 if a is a non-residue.

    Tonelli-Shanks. Inputs reducing to 0 or 1 are returned unchanged.
    """
    a %= MOD
    if a < 2:
        return a
    k = (MOD - 1) // 2
    if pow(a, k, MOD) != 1:  # Euler's criterion: a is not a square
        return -1
    b = 1
    while pow(b, k, MOD) == 1:  # search for a quadratic non-residue b
        b += 1
    m, e = MOD - 1, 0
    while m % 2 == 0:  # write MOD-1 = m * 2^e with m odd
        m >>= 1
        e += 1
    x = pow(a, (m - 1) // 2, MOD)
    y = a * x * x % MOD
    x = x * a % MOD
    z = pow(b, m, MOD)
    while y != 1:
        # j = smallest exponent with y^(2^j) == 1
        j, t = 0, y
        while t != 1:
            j += 1
            t = t * t % MOD
        z = pow(z, 1 << (e - j - 1), MOD)
        x = x * z % MOD
        z = z * z % MOD
        y = y * z % MOD
        e = j
    return x


def mod_inv(a):
    """Return the inverse of a modulo MOD, or 0 when a is divisible by MOD.

    Extended Euclid, O(log MOD).
    """
    a %= MOD
    if a == 0:
        return 0
    s, t = MOD, a
    m0, m1 = 0, 1
    while t:
        u = s // t
        s -= t * u
        m0 -= m1 * u
        s, t = t, s
        m0, m1 = m1, m0
    if m0 < 0:
        m0 += MOD // s
    return m0


fac_ = [1, 1]   # fac_[i] == i! mod MOD (grown on demand)
finv_ = [1, 1]  # finv_[i] == (i!)^-1 mod MOD
inv_ = [1, 1]   # inv_[i] == i^-1 mod MOD


def fac(i):
    """Return i! mod MOD, extending the cache as needed."""
    while i >= len(fac_):
        fac_.append(fac_[-1] * len(fac_) % MOD)
    return fac_[i]


def inv(i):
    """Return i^-1 mod MOD using inv(j) = -(MOD // j) * inv(MOD % j)."""
    while i >= len(inv_):
        j = len(inv_)
        inv_.append(-inv_[MOD % j] * (MOD // j) % MOD)
    return inv_[i]


def finv(i):
    """Return (i!)^-1 mod MOD, extending the caches as needed."""
    inv(i)  # make sure inv_ covers every index finv_ will read
    while i >= len(finv_):
        finv_.append(finv_[-1] * inv_[len(finv_)] % MOD)
    return finv_[i]


def butterfly(A):
    """In-place forward NTT of A modulo MOD; len(A) must be a power of two.

    Pointwise products of two transforms of equal length, followed by
    butterfly_inv, give the cyclic convolution (see convolution()).
    """
    n = len(A)
    h = (n - 1).bit_length()
    for ph in range(1, h + 1):
        w = 1 << (ph - 1)
        p = 1 << (h - ph)
        now = 1
        for s in range(w):
            offset = s << (h - ph + 1)
            for i in range(p):
                l = A[i + offset]
                r = A[i + offset + p] * now
                A[i + offset] = (l + r) % MOD
                A[i + offset + p] = (l - r) % MOD
            now = now * sum_e[(~s & -~s).bit_length() - 1] % MOD


def butterfly_inv(A):
    """In-place inverse of butterfly(), including the final division by len(A)."""
    n = len(A)
    h = (n - 1).bit_length()
    for ph in range(h, 0, -1):
        w = 1 << (ph - 1)
        p = 1 << (h - ph)
        inow = 1
        for s in range(w):
            offset = s << (h - ph + 1)
            for i in range(p):
                l = A[i + offset]
                r = A[i + offset + p]
                A[i + offset] = (l + r) % MOD
                A[i + offset + p] = (MOD + l - r) * inow % MOD
            inow = inow * sum_ie[(~s & -~s).bit_length() - 1] % MOD
    iz = mod_inv(n)
    for i in range(n):
        A[i] = A[i] * iz % MOD


def convolution(_A, _B):
    """Return the linear convolution C of _A and _B modulo MOD.

    C_k = sum_j A_j * B_(k-j); len(C) == len(_A) + len(_B) - 1, or [] if either
    input is empty. Inputs are not modified. Schoolbook multiplication when
    the shorter input has at most 60 terms, otherwise NTT: O(N log N).
    """
    A = _A.copy()
    B = _B.copy()
    n = len(A)
    m = len(B)
    if not n or not m:
        return []
    if min(n, m) <= 60:
        if n < m:
            n, m = m, n
            A, B = B, A
        res = [0] * (n + m - 1)
        for i in range(n):
            a = A[i]
            for j in range(m):
                res[i + j] = (res[i + j] + a * B[j]) % MOD
        return res
    z = 1 << (n + m - 2).bit_length()
    A += [0] * (z - n)
    B += [0] * (z - m)
    butterfly(A)
    butterfly(B)
    for i in range(z):
        A[i] = A[i] * B[i] % MOD
    butterfly_inv(A)
    return A[:n + m - 1]


def autocorrelation(_A):
    """Return convolution(_A, _A) using a single forward transform."""
    A = _A.copy()
    n = len(A)
    if not n:
        return []
    if n <= 60:
        res = [0] * (n + n - 1)
        for i in range(n):
            a = A[i]
            for j in range(n):
                res[i + j] = (res[i + j] + a * A[j]) % MOD
        return res
    z = 1 << (n + n - 2).bit_length()
    A += [0] * (z - n)
    butterfly(A)
    for i in range(z):
        A[i] = A[i] * A[i] % MOD
    butterfly_inv(A)
    return A[:n + n - 1]


class FormalPowerSeries:
    """Formal power series / polynomial f(z) = a_0 + a_1 z + ... over Z/MOD.

    ``poly`` holds [a_0, a_1, ...]. Coefficients are reduced modulo MOD on
    construction and on item assignment. len(poly) also serves as the default
    truncation length of series operations (inv, log, exp, sqrt, pow).
    """

    def __init__(self, poly=()):
        # Tuple default: no mutable default shared between calls; the
        # coefficients are always copied into a fresh list.
        self.poly = [p % MOD for p in poly]

    def __getitem__(self, key):
        """Slice -> new series; int -> coefficient (0 past the end).

        Negative integer indices raise IndexError (use ``.poly[-1]`` instead).
        """
        if isinstance(key, slice):
            return FormalPowerSeries(self.poly[key])
        if key < 0:
            raise IndexError("list index out of range")
        if key >= len(self.poly):
            return 0
        return self.poly[key]

    def __setitem__(self, key, value):
        """Set the coefficient of z^key, zero-extending the series if needed."""
        if key < 0:
            raise IndexError("list index out of range")
        if key >= len(self.poly):
            self.poly += [0] * (key - len(self.poly) + 1)
        self.poly[key] = value % MOD

    def __len__(self):
        return len(self.poly)

    def __str__(self):
        return str(self.poly)

    def __iter__(self):
        yield from self.poly

    def __pos__(self):
        return self

    def __neg__(self):
        return self.times(-1)

    def __add__(self, other):
        """Coefficient-wise sum; a non-series operand is treated as a constant."""
        if other.__class__ == FormalPowerSeries:
            s, t = len(self), len(other)
            res = [self.poly[i] + other.poly[i] for i in range(min(s, t))]
            res += self.poly[t:] if s >= t else other.poly[s:]
            return FormalPowerSeries(res)
        return self + FormalPowerSeries([other])

    def __radd__(self, other):
        return self + other

    def __sub__(self, other):
        return self + (-other)

    def __rsub__(self, other):
        return (-self) + other

    def __mul__(self, other):
        """Series product via convolution, O(N log N); scalar product otherwise."""
        if other.__class__ == FormalPowerSeries:
            return FormalPowerSeries(convolution(self.poly, other.poly))
        return self.times(other)

    def __rmul__(self, other):
        return self.times(other)

    def __lshift__(self, other):
        """Return f(z) * z^other."""
        return FormalPowerSeries([0] * other + self.poly)

    def __rshift__(self, other):
        """Return f(z) / z^other, dropping the lowest ``other`` coefficients."""
        return self[other:]

    def __truediv__(self, other):
        """Series division (multiply by other.inv()) or division by a scalar."""
        if other.__class__ == FormalPowerSeries:
            return self * other.inv()
        return self * mod_inv(other)

    def __rtruediv__(self, other):
        return other * self.inv()

    def __floordiv__(self, other):
        """Quotient of polynomial division of self by other.

        Both operands are treated as polynomials. other's last stored
        coefficient must be nonzero (call shrink() first), because it becomes
        the constant term of the reversed series that is inverted.
        """
        if other.__class__ == FormalPowerSeries:
            if len(self) < len(other):
                return FormalPowerSeries()
            m = len(self) - len(other) + 1
            # rev(quotient) = rev(self) * rev(other)^-1 mod z^m
            return (self[-1:-m - 1:-1] * other[::-1].inv(m))[m - 1::-1]
        return self * mod_inv(other)

    def __rfloordiv__(self, other):
        return other * self.inv()

    def __mod__(self, other):
        """Remainder of polynomial division (same precondition as //)."""
        if len(self) < len(other):
            return self[:]
        k = len(other) - 1
        return self[:k] - ((self // other) * other)[:k]

    def __pow__(self, n, deg=-1):
        """Return self**n truncated to degree ``deg`` (default len(self)-1).

        With c z^d the lowest nonzero term, computes
        exp(|n| log(f / (c z^d))) * c^|n| * z^(d|n|); negative n then inverts
        the result. O(N log N).
        """
        if deg == -1:
            deg = len(self) - 1
        m = abs(n)
        for d, p in enumerate(self.poly):
            if p:
                break
        else:
            return FormalPowerSeries()
        if d * m >= len(self):
            return FormalPowerSeries()
        G = self[d:]
        G = ((G * mod_inv(p)).log() * m).exp() * pow(p, m, MOD)
        res = FormalPowerSeries([0] * (d * m) + G.poly)
        if n >= 0:
            return res[:deg + 1]
        return res.inv()[:deg + 1]

    def resize(self, size):
        """Return a copy truncated or zero-padded to exactly ``size`` terms."""
        if len(self) >= size:
            return self[:size]
        return FormalPowerSeries(self.poly + [0] * (size - len(self)))

    def shrink(self):
        """Drop trailing zero coefficients in place."""
        while self.poly and self.poly[-1] == 0:
            self.poly.pop()

    def times(self, n):
        """Return the series multiplied by the scalar n."""
        n %= MOD
        return FormalPowerSeries([p * n for p in self.poly])

    def square(self):
        """Return self * self (full product, not truncated)."""
        return FormalPowerSeries(autocorrelation(self.poly))

    def inv(self, deg=-1):
        """Return g with f(z) g(z) == 1 mod z^(deg+1) (Newton iteration).

        self[0] must be invertible mod MOD. O(N log N).
        """
        if deg == -1:
            deg = len(self) - 1
        r = mod_inv(self[0])
        m = 1
        T = [0] * (deg + 1)
        T[0] = r
        res = FormalPowerSeries(T)
        while m <= deg:
            F = [0] * (2 * m)
            G = [0] * (2 * m)
            for j in range(min(len(self), 2 * m)):
                F[j] = self[j]
            for j in range(m):
                G[j] = res[j]
            butterfly(F)
            butterfly(G)
            for j in range(2 * m):
                F[j] = F[j] * G[j] % MOD
            butterfly_inv(F)
            # f*g == 1 mod z^m, so only the upper half carries the error.
            for j in range(m):
                F[j] = 0
            butterfly(F)
            for j in range(2 * m):
                F[j] = F[j] * G[j] % MOD
            butterfly_inv(F)
            for j in range(m, min(2 * m, deg + 1)):
                res[j] = -F[j]
            m <<= 1
        return res

    def differentiate(self, deg=-1):
        """Return f'(z) truncated to degree ``deg`` (default len(self)-2). O(N)."""
        if deg == -1:
            deg = len(self) - 2
        res = FormalPowerSeries([0] * (deg + 1))
        for i in range(1, min(len(self), deg + 2)):
            res[i - 1] = self[i] * i
        return res

    def integrate(self, deg=-1):
        """Return the antiderivative with zero constant term, up to degree ``deg``. O(N)."""
        if deg == -1:
            deg = len(self)
        res = FormalPowerSeries([0] * (deg + 1))
        for i in range(min(len(self), deg)):
            res[i + 1] = self[i] * inv(i + 1)  # module-level inv(), not the method
        return res

    def log(self, deg=-1):
        """Return log f(z) = integral(f'/f), truncated to degree ``deg``. O(N log N)."""
        if deg == -1:
            deg = len(self) - 1
        return (self.differentiate() * self.inv(deg - 1))[:deg].integrate()

    def exp(self, deg=-1):
        """Return exp f(z) truncated to degree ``deg`` (Newton iteration).

        Maintains res = exp(f) and T = res^-1 to doubling precision m.
        O(N log N).
        """
        if deg == -1:
            deg = len(self) - 1
        T = [0] * (deg + 1)
        T[0] = 1  # T: res^{-1}
        res = FormalPowerSeries(T)
        m = 1
        F = [1]  # forward transform of res[:m]
        while m <= deg:
            G = T[:m]
            butterfly(G)
            FG = [F[i] * G[i] % MOD for i in range(m)]
            butterfly_inv(FG)
            FG[0] -= 1
            delta = [0] * (2 * m)
            for i in range(m):
                delta[i + m] = FG[i]
            eps = [0] * (2 * m)
            if m == 1:
                DF = []
            else:
                DF = res.differentiate(m - 2).poly
            DF.append(0)
            butterfly(DF)
            for i in range(m):
                DF[i] = DF[i] * G[i] % MOD
            butterfly_inv(DF)
            for i in range(m - 1):
                eps[i] = self[i + 1] * (i + 1) % MOD
                eps[i + m] = DF[i] - eps[i]
            eps[m - 1] = DF[m - 1]
            butterfly(delta)
            DH = [0] * (2 * m)
            for i in range(m - 1):
                DH[i] = eps[i]
            butterfly(DH)
            for i in range(2 * m):
                delta[i] = delta[i] * DH[i] % MOD
            butterfly_inv(delta)
            for i in range(m, 2 * m):
                eps[i] = (eps[i] - delta[i]) % MOD
            # Integrate and subtract f: eps becomes log(res) - f.
            for i in range(2 * m - 1, 0, -1):
                eps[i] = (eps[i - 1] * inv(i) - self[i]) % MOD
            eps[0] = -self[0]
            butterfly(eps)
            for i in range(m):
                DH[i] = res[i]
                DH[i + m] = 0
            butterfly(DH)
            for i in range(2 * m):
                eps[i] = eps[i] * DH[i] % MOD
            butterfly_inv(eps)
            for i in range(m, min(2 * m, deg + 1)):
                res[i] = -eps[i]
            if 2 * m > deg:
                break
            # Extend T = res^-1 to precision 2m (one Newton step for inverse).
            F = [0] * (2 * m)
            G = [0] * (2 * m)
            for i in range(2 * m):
                F[i] = res[i]
            for i in range(m):
                G[i] = T[i]
            butterfly(F)
            butterfly(G)
            P = [F[i] * G[i] % MOD for i in range(2 * m)]
            butterfly_inv(P)
            for i in range(m):
                P[i] = 0
            butterfly(P)
            for i in range(2 * m):
                P[i] = P[i] * G[i] % MOD
            butterfly_inv(P)
            for i in range(m, 2 * m):
                T[i] = -P[i]
            m <<= 1
        return res

    def sqrt(self, deg=-1):
        """Return some g with g^2 == f mod z^(deg+1), or -1 if none exists.

        A series whose lowest nonzero term has odd degree, or a non-residue
        constant term, yields -1. Newton iteration, O(N log N).
        """
        if deg == -1:
            deg = len(self) - 1
        if len(self) == 0:
            return FormalPowerSeries()
        if self[0] == 0:
            for d in range(1, len(self)):
                if self[d]:
                    if d & 1:
                        return -1
                    if deg < d // 2:
                        break
                    res = self[d:].sqrt(deg - d // 2)
                    if res == -1:
                        return -1
                    return res << (d // 2)
            return FormalPowerSeries()
        sqr = mod_sqrt(self[0])
        if sqr == -1:
            return -1
        T = [0] * (deg + 1)
        T[0] = sqr
        res = FormalPowerSeries(T)
        T[0] = mod_inv(sqr)  # T: res^{-1}
        m = 1
        two_inv = (MOD + 1) // 2
        F = [sqr]  # forward transform of res[:m]
        while m <= deg:
            for i in range(m):
                F[i] = F[i] * F[i] % MOD
            butterfly_inv(F)  # res^2 mod (z^m - 1)
            delta = [0] * (2 * m)
            for i in range(m):
                delta[i + m] = F[i] - self[i] - self[i + m]
            butterfly(delta)
            G = [0] * (2 * m)
            for i in range(m):
                G[i] = T[i]
            butterfly(G)
            for i in range(2 * m):
                delta[i] = delta[i] * G[i] % MOD
            butterfly_inv(delta)
            for i in range(m, min(2 * m, deg + 1)):
                res[i] = -delta[i] * two_inv
            if 2 * m > deg:
                break
            F = res.poly[:2 * m]
            butterfly(F)
            eps = [F[i] * G[i] % MOD for i in range(2 * m)]
            butterfly_inv(eps)
            for i in range(m):
                eps[i] = 0
            butterfly(eps)
            for i in range(2 * m):
                eps[i] = eps[i] * G[i] % MOD
            butterfly_inv(eps)
            for i in range(m, 2 * m):
                T[i] = -eps[i]
            m <<= 1
        return res

    def multipoint_evaluation(self, P):
        """Return [f(P_0), f(P_1), ...] using a subproduct (remainder) tree."""
        m = len(P)
        size = 1 << (m - 1).bit_length()
        G = [FormalPowerSeries([1]) for _ in range(2 * size)]
        for i in range(m):
            G[size + i] = FormalPowerSeries([-P[i], 1])
        for i in range(size - 1, 0, -1):
            G[i] = G[2 * i] * G[2 * i + 1]
        G[1] = self % G[1]
        for i in range(2, size + m):
            G[i] = G[i >> 1] % G[i]
        return [G[i][0] for i in range(size, size + m)]

    def taylor_shift(self, a):
        """Return f(z + a) with the same number of coefficients. O(N log N)."""
        a %= MOD
        n = len(self)
        t = 1
        F = self[:]
        G = FormalPowerSeries([0] * n)
        for i in range(n):
            F[i] *= fac(i)
        for i in range(n):
            G[i] = t * finv(i)
            t = t * a % MOD
        res = (F * G[::-1])[n - 1:]
        for i in range(n):
            res[i] *= finv(i)
        return res

    def composition(self, g, deg=-1):
        """Return f(g(z)) truncated to degree ``deg`` (baby-step giant-step).

        Uses about 2*sqrt(deg) series products plus an O(deg^2) coefficient
        gathering loop.
        """
        if deg == -1:
            deg = len(self) - 1
        k = int(deg ** 0.5 + 1)
        d = (deg + k) // k
        X = [FormalPowerSeries([1])]  # X[i] = g^i mod z^(deg+1)
        for i in range(k):
            X.append((X[i] * g)[:deg + 1])
        Y = [FormalPowerSeries([0] * (deg + 1)) for _ in range(k)]
        for i in range(k):
            for j in range(d):
                if i * d + j > deg:
                    break
                c = self[i * d + j]
                for t in range(min(deg + 1, len(X[j]))):
                    Y[i][t] += X[j][t] * c
        res = FormalPowerSeries([0] * (deg + 1))
        Z = FormalPowerSeries([1])  # Z = g^(i*d)
        for i in range(k):
            Y[i] = (Y[i] * Z)[:deg + 1]
            for j in range(len(Y[i])):
                res[j] += Y[i][j]
            Z = (Z * X[d])[:deg + 1]
        return res


def poly_coef(Q, P, n):
    """Return [x^n] P(x)/Q(x) mod MOD (Bostan-Mori); requires deg(Q) > deg(P).

    Q[0] must be invertible. Cost O(K log K log n) with K = len(Q).
    """
    if type(P) == FormalPowerSeries:
        P = P.poly
    if type(Q) == FormalPowerSeries:
        Q = Q.poly
    m = 1 << (len(Q) - 1).bit_length()
    P = P + [0] * (2 * m - len(P))
    Q = Q + [0] * (2 * m - len(Q))
    while n:
        R = [0] * (2 * m)
        for i, q in enumerate(Q):
            R[i] = (1 - 2 * (i & 1)) * q  # R(x) = Q(-x)
        butterfly(P)
        butterfly(Q)
        butterfly(R)
        for i, r in enumerate(R):
            P[i] = P[i] * r % MOD
            Q[i] = Q[i] * r % MOD
        butterfly_inv(P)
        butterfly_inv(Q)
        # Keep the odd/even part of P and the even part of Q (halve n).
        if n & 1:
            for i in range(m):
                P[i] = P[2 * i + 1]
        else:
            for i in range(m):
                P[i] = P[2 * i]
        for i in range(m):
            Q[i] = Q[2 * i]
        for i in range(m, 2 * m):
            P[i] = 0
            Q[i] = 0
        n >>= 1
    # The constant term of Q is not necessarily 1, so divide by it.
    return P[0] * mod_inv(Q[0]) % MOD


def subset_sum(A, limit):
    """Return [#subsets of A with sum k for k in 0..limit] modulo MOD.

    Computed as exp(sum_a log(1 + x^a)). Items larger than ``limit`` cannot
    affect these counts and are skipped; zero-valued items double every
    count. Cost O(N + limit log limit).
    """
    zeros = 0
    C = [0] * (limit + 1)
    for a in A:
        if a < 0:
            raise ValueError("subset_sum expects non-negative items")
        if a == 0:
            zeros += 1
        elif a <= limit:
            C[a] += 1
    res = FormalPowerSeries([0] * (limit + 1))
    for i in range(1, limit + 1):
        if not C[i]:
            continue
        for k in range(1, limit // i + 1):
            res[i * k] += ((k & 1) * 2 - 1) * C[i] * inv(k)
    out = res.exp(limit).poly
    if zeros:
        scale = pow(2, zeros, MOD)
        out = [v * scale % MOD for v in out]
    return out


def partition_function(n):
    """Return [p(0), ..., p(n)] mod MOD via the pentagonal number theorem."""
    res = FormalPowerSeries([0] * (n + 1))
    res[0] = 1
    for k in range(1, n + 1):
        k1 = k * (3 * k + 1) // 2
        k2 = k * (3 * k - 1) // 2
        if k2 > n:
            break
        sign = 1 - (k & 1) * 2
        if k1 <= n:
            res[k1] += sign
        if k2 <= n:
            res[k2] += sign
    return res.inv().poly


def bernoulli_number(n):
    """Return [B_0, ..., B_n] mod MOD from x / (e^x - 1)."""
    n += 1
    Q = FormalPowerSeries([finv(i + 1) for i in range(n)]).inv(n - 1)
    return [v * fac(i) % MOD for i, v in enumerate(Q.poly)]


def stirling_first(n):
    """Return the coefficients of x(x-1)...(x-n+1) (signed Stirling numbers of the first kind).

    Built by binary doubling: f_{2t}(x) = f_t(x) f_t(x-t), f_{t+1} = f_t (x-t).
    """
    P = []
    a = n
    while a:
        if a & 1:
            P.append(1)
        P.append(0)
        a >>= 1
    res = FormalPowerSeries([1])
    t = 0
    for x in reversed(P):
        if x:
            res *= FormalPowerSeries([-t, 1])
            t += 1
        else:
            res *= res.taylor_shift(-t)
            t *= 2
    return res.poly


def stirling_second(n):
    """Return [S(n, 0), ..., S(n, n)] (Stirling numbers of the second kind) mod MOD."""
    F = FormalPowerSeries([0] * (n + 1))
    G = FormalPowerSeries([0] * (n + 1))
    for i in range(n + 1):
        F[i] = pow(i, n, MOD) * finv(i)
        G[i] = (1 - (i & 1) * 2) * finv(i)
    return (F * G)[:n + 1].poly


def polynominal_interpolation(X, Y):
    """Return the polynomial of degree < n through the points (X_i, Y_i).

    Uses a subproduct tree and multipoint evaluation of M'(x).
    """
    n = len(X)
    size = 1 << (n - 1).bit_length()
    M = [FormalPowerSeries([1]) for _ in range(2 * size)]
    G = [0] * (2 * size)
    for i in range(n):
        M[size + i] = FormalPowerSeries([-X[i], 1])
    for i in range(size - 1, 0, -1):
        M[i] = M[2 * i] * M[2 * i + 1]
    G[1] = M[1].differentiate() % M[1]
    for i in range(2, size + n):
        G[i] = G[i >> 1] % M[i]
    for i in range(n):
        G[size + i] = FormalPowerSeries([Y[i] * mod_inv(G[size + i][0])])
    for i in range(size - 1, 0, -1):
        G[i] = G[2 * i] * M[2 * i + 1] + G[2 * i + 1] * M[2 * i]
    return G[1][:n]


class Mat2:
    """2x2 matrix of FormalPowerSeries entries [[a00, a01], [a10, a11]].

    Defaults to the identity. Every instance gets its own default entries.
    """

    def __init__(self, a00=None, a01=None, a10=None, a11=None):
        # None sentinels: fresh series per instance instead of shared defaults.
        self.a00 = FormalPowerSeries([1]) if a00 is None else a00
        self.a01 = FormalPowerSeries() if a01 is None else a01
        self.a10 = FormalPowerSeries() if a10 is None else a10
        self.a11 = FormalPowerSeries([1]) if a11 is None else a11

    def _product(self, other):
        # Entries of self * other, each with trailing zeros removed.
        A00 = self.a00 * other.a00 + self.a01 * other.a10
        A01 = self.a00 * other.a01 + self.a01 * other.a11
        A10 = self.a10 * other.a00 + self.a11 * other.a10
        A11 = self.a10 * other.a01 + self.a11 * other.a11
        for entry in (A00, A01, A10, A11):
            entry.shrink()
        return A00, A01, A10, A11

    def __mul__(self, other):
        """Matrix * Mat2 -> Mat2; matrix * (p0, p1) pair -> (b0, b1) tuple."""
        if type(other) == Mat2:
            return Mat2(*self._product(other))
        b0 = self.a00 * other[0] + self.a01 * other[1]
        b1 = self.a10 * other[0] + self.a11 * other[1]
        b0.shrink()
        b1.shrink()
        return (b0, b1)

    def __imul__(self, other):
        self.a00, self.a01, self.a10, self.a11 = self._product(other)
        return self  # in-place operators must return the result


def _inner_naive_gcd(m, p):
    """Do one Euclidean step on p = (p0, p1); left-multiply m by [[0,1],[1,-q]].

    Returns the new pair (p1, p0 mod p1). m is updated in place.
    """
    quo = p[0] // p[1]
    rem = p[0] - p[1] * quo
    b10 = m.a00 - m.a10 * quo
    b11 = m.a01 - m.a11 * quo
    rem.shrink()
    b10.shrink()
    b11.shrink()
    m.a00, m.a01, m.a10, m.a11 = m.a10, m.a11, b10, b11
    return (p[1], rem)


def _inner_half_gcd(p):
    """Half-GCD recursion: a matrix of Euclidean steps computed from only the
    top coefficients of p = (p0, p1), with len(p0) >= len(p1) and both shrunk.
    """
    n, m = len(p[0]), len(p[1])
    k = (n + 1) // 2
    if m <= k:
        return Mat2()
    m1 = _inner_half_gcd((p[0] >> k, p[1] >> k))
    p = m1 * p
    if len(p[1]) <= k:
        return m1
    p = _inner_naive_gcd(m1, p)
    if len(p[1]) <= k:
        return m1
    deg0 = len(p[0]) - 1
    # Keep only the coefficients the recursive call needs. This was
    # `2*k-1` (a typo for 2*k - deg0), which dropped needed coefficients.
    j = 2 * k - deg0
    p = (p[0] >> j, p[1] >> j)
    return _inner_half_gcd(p) * m1


def _inner_poly_gcd(a, b):
    """Return M such that (M * (a, b))[0] is a gcd of a and b (up to a scalar).

    a and b are not modified; shrunk copies are used internally.
    """
    p = (a[:], b[:])
    p[0].shrink()
    p[1].shrink()
    n, m = len(p[0]), len(p[1])
    if n < m:
        mat = _inner_poly_gcd(p[1], p[0])
        mat.a00, mat.a01 = mat.a01, mat.a00
        mat.a10, mat.a11 = mat.a11, mat.a10
        return mat
    res = Mat2()
    while True:
        m1 = _inner_half_gcd(p)
        p = m1 * p
        if len(p[1]) == 0:
            return m1 * res
        p = _inner_naive_gcd(m1, p)
        if len(p[1]) == 0:
            return m1 * res
        res = m1 * res


def poly_gcd(a, b):
    """Return the monic gcd of polynomials a and b (empty if both are zero)."""
    m = _inner_poly_gcd(a, b)
    g = (m * (a, b))[0]  # already shrunk, so the last coefficient is nonzero
    if len(g):
        # .poly[-1]: __getitem__ rejects negative indices; the product also
        # avoids the old in-place assignment into a tuple element.
        g = g * mod_inv(g.poly[-1])
    return g


def poly_inv(f, g):
    """Return h with f*h == 1 (mod g), or -1 if gcd(f, g) is not a constant.

    f and g are not modified.
    """
    f = f[:]
    g = g[:]
    f.shrink()
    g.shrink()  # `% g` below needs g's leading coefficient to be nonzero
    m = _inner_poly_gcd(f, g)
    _gcd = (m * (f, g))[0]
    if len(_gcd) != 1:
        return -1
    x = (FormalPowerSeries([1]), g)
    res = ((m * x)[0] % g) * mod_inv(_gcd[0])
    res.shrink()
    return res


def linear_recurrence(A, C, n):
    """Return a_n for a_n = C_0 a_(n-1) + C_1 a_(n-2) + ... + C_(K-1) a_(n-K).

    A = [a_0, ..., a_(K-1)], C = [C_0, ..., C_(K-1)]. O(K log K log n).
    """
    K = len(A)
    Q = [0] * (K + 1)
    Q[0] = 1
    for i, c in enumerate(C, 1):
        Q[i] = -c
    P = convolution(A, Q)[:K]
    return poly_coef(Q, P, n)


#######################################################
# Driver


def product_of_linears(A):
    """Return prod_i (z + A_i - 1) as a FormalPowerSeries.

    Factors are multiplied pairwise in balanced rounds (O(N log^2 N)).
    Folding them one by one into a single accumulator costs O(N^2) and was
    the cause of the TLE.
    """
    polys = [[(a - 1) % MOD, 1] for a in A]
    if not polys:
        return FormalPowerSeries([1])
    while len(polys) > 1:
        merged = [convolution(polys[i], polys[i + 1])
                  for i in range(0, len(polys) - 1, 2)]
        if len(polys) & 1:
            merged.append(polys[-1])
        polys = merged
    return FormalPowerSeries(polys[0])


def solve(A, B):
    """Return [[z^b] prod_i (z + A_i - 1) mod MOD for b in B]."""
    F = product_of_linears(A)
    return [F[b] for b in B]


def main():
    """Read "N Q", then N values A, then Q queries B; print one answer per query."""
    data = sys.stdin.buffer.read().split()
    n, q = int(data[0]), int(data[1])
    A = [int(x) for x in data[2:2 + n]]
    B = [int(x) for x in data[2 + n:2 + n + q]]
    ans = solve(A, B)
    if ans:
        sys.stdout.write("\n".join(map(str, ans)) + "\n")


if __name__ == "__main__":
    main()