結果
| 問題 | No.1068 #いろいろな色 / Red and Blue and more various colors (Hard) |
| コンテスト | |
| ユーザー | |
| 提出日時 | 2021-05-12 18:52:28 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 2,039 ms / 3,500 ms |
| コード長 | 22,401 bytes |
| コンパイル時間 | 205 ms |
| コンパイル使用メモリ | 82,076 KB |
| 実行使用メモリ | 232,632 KB |
| 最終ジャッジ日時 | 2024-09-23 04:19:20 |
| 合計ジャッジ時間 | 39,983 ms |
| ジャッジサーバーID (参考情報) | judge1 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 29 |
ソースコード
# Prime modulus used by every modular operation in this file.
MOD=998244353
# Twiddle-factor step tables for butterfly() (sum_e) and butterfly_inv()
# (sum_ie): after finishing block s of a stage, the running twiddle factor is
# multiplied by table[i], where i is the index of the lowest zero bit of s.
# NOTE(review): these look like AtCoder Library's precomputed sum_e / sum_ie
# tables for MOD=998244353 (primitive root 3); they are only valid for this
# MOD. With 21 entries, the largest index reached (log2(length) - 1) limits
# transform lengths to 2^21 — confirm before using longer transforms.
sum_e=(
911660635,509520358,369330050,332049552,983190778,123842337,238493703,975955924,603855026,856644456,131300601,842657263,
730768835,942482514,806263778,151565301,510815449,503497456,743006876,741047443,56250497)
sum_ie=(
86583718,372528824,373294451,645684063,112220581,692852209,155456985,797128860,90816748,860285882,927414960,354738543,
109331171,293255632,535113200,308540755,121186627,608385704,438932459,359477183,824071951)
def mod_sqrt(a):
    """
    Return some x with x*x == a (mod MOD), or -1 if a is a non-residue.

    0 and 1 are returned unchanged. Otherwise Euler's criterion decides
    solvability and Tonelli-Shanks finds the root.
    Cost: O(log^2 MOD) modular multiplications in the worst case.
    """
    a %= MOD
    if a < 2:
        return a
    half = (MOD - 1) // 2
    # Euler's criterion: a is a square iff a^((MOD-1)/2) == 1.
    if pow(a, half, MOD) != 1:
        return -1
    # Smallest quadratic non-residue.
    nonres = 1
    while pow(nonres, half, MOD) == 1:
        nonres += 1
    # Factor MOD - 1 = odd * 2^e.
    odd, e = MOD - 1, 0
    while not odd & 1:
        odd >>= 1
        e += 1
    x = pow(a, (odd - 1) // 2, MOD)
    y = a * x * x % MOD
    x = x * a % MOD
    z = pow(nonres, odd, MOD)
    while y != 1:
        # j = smallest exponent with y^(2^j) == 1.
        j, t = 0, y
        while t != 1:
            j += 1
            t = t * t % MOD
        z = pow(z, 1 << (e - j - 1), MOD)
        x = x * z % MOD
        z = z * z % MOD
        y = y * z % MOD
        e = j
    return x
def mod_inv(a):
    """
    Return the inverse of a modulo MOD (extended Euclidean algorithm).

    Returns 0 when a is a multiple of MOD, which has no inverse.
    Cost: O(log MOD).
    """
    a %= MOD
    if not a:
        return 0
    s, t = MOD, a
    m0, m1 = 0, 1
    while t:
        q = s // t
        s, t = t, s - t * q
        m0, m1 = m1, m0 - m1 * q
    if m0 < 0:
        m0 += MOD // s
    return m0
# Lazily grown lookup tables (values mod MOD) shared by fac(), finv() and inv():
#   fac_[i] = i!, finv_[i] = (i!)^-1, inv_[i] = i^-1.
# inv_[0] holds a placeholder 1, since 0 has no inverse.
fac_=[1,1]
finv_=[1,1]
inv_=[1,1]
def fac(i):
    """Return i! mod MOD, extending the cached table fac_ as needed."""
    table = fac_
    while len(table) <= i:
        k = len(table)
        table.append(table[-1] * k % MOD)
    return table[i]
def finv(i):
    """Return (i!)^-1 mod MOD, extending the cached tables inv_ and finv_ as needed."""
    # Grow inv_ to cover index i; finv_ is built from it below.
    inv(i)
    while len(finv_) <= i:
        k = len(finv_)
        finv_.append(finv_[k - 1] * inv_[k] % MOD)
    return finv_[i]
def inv(i):
    """
    Return i^-1 mod MOD for i >= 1, extending the cached table inv_ as needed.

    inv(0) returns the placeholder value 1 stored in inv_[0].
    """
    while len(inv_) <= i:
        k = len(inv_)
        # Recurrence: k^-1 == -(MOD // k) * (MOD % k)^-1  (mod MOD)
        inv_.append(-inv_[MOD % k] * (MOD // k) % MOD)
    return inv_[i]
def butterfly(A):
    """
    In-place forward number-theoretic transform of A (values mod MOD).

    The loop structure only covers the whole list when len(A) is a power of
    two; every caller in this file pads to one. The output order matches what
    butterfly_inv() expects, so a pointwise product of two butterfly() outputs
    followed by butterfly_inv() gives a cyclic convolution.
    """
    size = len(A)
    log_size = (size - 1).bit_length()
    for level in range(1, log_size + 1):
        blocks = 1 << (level - 1)
        span = 1 << (log_size - level)
        twiddle = 1
        for blk in range(blocks):
            base = blk << (log_size - level + 1)
            for lo in range(base, base + span):
                hi = lo + span
                left = A[lo]
                right = A[hi] * twiddle
                A[lo] = (left + right) % MOD
                A[hi] = (left - right) % MOD
            # Step the twiddle by the table entry for the lowest zero bit of blk.
            twiddle = twiddle * sum_e[(~blk & (blk + 1)).bit_length() - 1] % MOD
def butterfly_inv(A):
    """
    In-place inverse of butterfly(), including the final 1/len(A) scaling.

    Like butterfly(), it assumes len(A) is a power of two.
    """
    size = len(A)
    log_size = (size - 1).bit_length()
    for level in range(log_size, 0, -1):
        blocks = 1 << (level - 1)
        span = 1 << (log_size - level)
        twiddle = 1
        for blk in range(blocks):
            base = blk << (log_size - level + 1)
            for lo in range(base, base + span):
                hi = lo + span
                left, right = A[lo], A[hi]
                A[lo] = (left + right) % MOD
                A[hi] = (MOD + left - right) * twiddle % MOD
            # Step the inverse twiddle by the entry for the lowest zero bit of blk.
            twiddle = twiddle * sum_ie[(~blk & (blk + 1)).bit_length() - 1] % MOD
    scale = mod_inv(size)
    for idx in range(size):
        A[idx] = A[idx] * scale % MOD
def convolution(_A,_B):
    """
    Return the linear convolution of _A and _B modulo MOD.

    The result C has length len(_A) + len(_B) - 1, with
    C[k] = sum(_A[i] * _B[k - i]) over every valid i.
    Returns [] when either input is empty. The inputs are copied and never
    modified. The schoolbook method is used when the shorter input has at most
    60 entries; otherwise an NTT of the next power-of-two length is used.
    Cost: O(N log N).
    """
    fa = _A.copy()
    fb = _B.copy()
    len_a = len(fa)
    len_b = len(fb)
    if len_a == 0 or len_b == 0:
        return []
    out_len = len_a + len_b - 1
    if min(len_a, len_b) <= 60:
        outer, inner = (fa, fb) if len_a >= len_b else (fb, fa)
        out = [0] * out_len
        for i, x in enumerate(outer):
            for j, y in enumerate(inner):
                out[i + j] = (out[i + j] + x * y) % MOD
        return out
    size = 1 << (out_len - 1).bit_length()
    fa.extend([0] * (size - len_a))
    fb.extend([0] * (size - len_b))
    butterfly(fa)
    butterfly(fb)
    for i in range(size):
        fa[i] = fa[i] * fb[i] % MOD
    butterfly_inv(fa)
    return fa[:out_len]
def autocorrelation(_A):
    """
    Return the convolution of _A with itself modulo MOD (coefficients of A(x)^2).

    The result has length 2*len(_A) - 1, or is [] for empty input. The input
    is copied. The schoolbook method is used for at most 60 entries, otherwise
    a single NTT (one forward transform suffices).
    """
    work = _A.copy()
    length = len(work)
    if length == 0:
        return []
    out_len = 2 * length - 1
    if length <= 60:
        out = [0] * out_len
        for i, x in enumerate(work):
            for j, y in enumerate(work):
                out[i + j] = (out[i + j] + x * y) % MOD
        return out
    size = 1 << (out_len - 1).bit_length()
    work.extend([0] * (size - length))
    butterfly(work)
    for i in range(size):
        work[i] = work[i] * work[i] % MOD
    butterfly_inv(work)
    return work[:out_len]
class FormalPowerSeries:
    """
    Truncated formal power series / polynomial with coefficients mod MOD.

    f(x) = a_0 x^0 + a_1 x^1 + ...  is stored as  poly = [a_0, a_1, ...].
    The constructor and __setitem__ reduce every coefficient mod MOD.
    Reading an integer index past the end returns 0. Most operations return
    a new series; __setitem__ and shrink() mutate in place.
    """
    def __init__(self,poly=[]):
        # The default list is never mutated: the comprehension builds a new one.
        self.poly=[p%MOD for p in poly]
    def __getitem__(self,key):
        # Slices return a new series. Integer indices past the end read as 0.
        # Negative integer indices are rejected (use self.poly[-1] instead).
        if isinstance(key,slice):
            res=self.poly[key]
            return FormalPowerSeries(res)
        else:
            if key<0:
                raise IndexError("list index out of range")
            if key>=len(self.poly):
                return 0
            else:
                return self.poly[key]
    def __setitem__(self,key,value):
        # Writing past the end zero-extends poly first; the value is reduced mod MOD.
        if key<0:
            raise IndexError("list index out of range")
        if key>=len(self.poly):
            self.poly+=[0]*(key-len(self.poly)+1)
        self.poly[key]=value%MOD
    def __len__(self):
        return len(self.poly)
    def __str__(self):
        return str(self.poly)
    def __iter__(self):
        for p in self.poly:
            yield p
    def __pos__(self):
        # NOTE(review): returns self rather than a copy, so later in-place
        # mutation of the result also changes the original.
        return self
    def __neg__(self):
        return self.times(-1)
    def __add__(self,other):
        """Coefficient-wise sum. A non-series operand is treated as a constant term."""
        if other.__class__==FormalPowerSeries:
            s=len(self)
            t=len(other)
            n=min(s,t)
            res=[self[i]+other[i] for i in range(n)]
            # Copy the tail of whichever operand is longer.
            if s>=t:
                res+=self.poly[t:]
            else:
                res+=other.poly[s:]
            return FormalPowerSeries(res)
        else:
            return self+FormalPowerSeries([other])
    def __radd__(self,other):
        return self+other
    def __sub__(self,other):
        return self+(-other)
    def __rsub__(self,other):
        return (-self)+other
    def __mul__(self,other):
        """
        Series * series: full (untruncated) product via convolution, O(N log N).
        Series * scalar: coefficient-wise scaling.
        """
        if other.__class__==FormalPowerSeries:
            res=convolution(self.poly,other.poly)
            return FormalPowerSeries(res)
        else:
            return self.times(other)
    def __rmul__(self,other):
        return self.times(other)
    def __lshift__(self,other):
        """
        Return f(z)*z^other (prepend `other` zero coefficients).
        """
        return FormalPowerSeries(([0]*other)+self.poly)
    def __rshift__(self,other):
        """
        Return f(z)/z^other, discarding the lowest `other` coefficients.
        """
        return self[other:]
    def __truediv__(self,other):
        # Series: self * other.inv(), where other.inv() has len(other) terms.
        # Scalar: multiply by its modular inverse.
        if other.__class__==FormalPowerSeries:
            return (self*other.inv())
        else:
            return self*mod_inv(other)
    def __rtruediv__(self,other):
        return other*self.inv()
    # Quotient of P divided by Q, treating both as polynomials (P = self, Q = other).
    def __floordiv__(self,other):
        if other.__class__==FormalPowerSeries:
            if len(self)<len(other):
                return FormalPowerSeries()
            else:
                # Reversal trick: rev(quotient) = rev(P) * rev(Q)^{-1} mod x^m,
                # where m = number of quotient coefficients.
                # NOTE(review): requires other's last stored coefficient to be
                # nonzero (otherwise rev(Q) has no inverse).
                m=len(self)-len(other)+1
                res=(self[-1:-m-1:-1]*other[::-1].inv(m))[m-1::-1]
                return res
        else:
            # A scalar divisor is treated as modular division.
            return self*mod_inv(other)
    def __rfloordiv__(self,other):
        return other*self.inv()
    def __mod__(self,other):
        # Polynomial remainder: the low len(other)-1 coefficients of
        # self - (self//other)*other. Returns a copy of self when it is shorter.
        if len(self)<len(other):
            return self[:]
        else:
            res=self[:len(other)-1]-((self//other)*other)[:len(other)-1]
            return res
    def __pow__(self,n,deg=-1):
        """
        Return self**n truncated to deg+1 terms (default len(self) terms).

        Factors out the lowest nonzero term c*x^d and computes
        c^n * x^(d*n) * exp(n * log(G/c)), where G = self/x^d.
        Negative n inverts the result. Python's three-argument pow(f, n, k)
        passes k as deg here, so it truncates to k+1 terms.
        NOTE(review): an all-zero series returns an empty series even for
        n == 0, and the early exit compares d*|n| with len(self) rather
        than deg+1.
        O(N log N)
        """
        if deg==-1:
            deg=len(self)-1
        m=abs(n)
        for d,p in enumerate(self.poly):
            if p:
                break
        else:
            return FormalPowerSeries()
        if d*m>=len(self):
            return FormalPowerSeries()
        G=self[d:]
        G=((G*mod_inv(p)).log()*m).exp()*pow(p,m,MOD)
        res=FormalPowerSeries([0]*(d*m)+G.poly)
        if n>=0:
            return res[:deg+1]
        else:
            return res.inv()[:deg+1]
    def resize(self,size):
        # Return a copy truncated or zero-padded to exactly `size` coefficients.
        if len(self)>=size:
            return self[:size]
        else:
            return FormalPowerSeries(self.poly+[0]*(size-len(self)))
    def shrink(self):
        # In place: strip trailing zero coefficients.
        while self.poly and self.poly[-1]==0:
            self.poly.pop()
    def times(self,n):
        # Return a new series with every coefficient multiplied by the scalar n.
        n%=MOD
        res=[p*n for p in self.poly]
        return FormalPowerSeries(res)
    def square(self):
        # Return self*self (full product) using a single NTT.
        res=autocorrelation(self.poly)
        return FormalPowerSeries(res)
    def inv(self,deg=-1):
        """
        Return g with f(z)g(z) == 1 mod z^(deg+1) (default deg = len(self)-1).

        Newton iteration doubling the number m of known terms each round.
        NOTE(review): requires self[0] != 0; mod_inv(0) returns 0, giving a
        meaningless result instead of an error.
        O(N log N)
        """
        if deg==-1:
            deg=len(self)-1
        r=mod_inv(self[0])
        m=1
        T=[0]*(deg+1)
        T[0]=r
        res=FormalPowerSeries(T)
        while m<=deg:
            F=[0]*(2*m)
            G=[0]*(2*m)
            for j in range(min(len(self),2*m)):
                F[j]=self[j]
            for j in range(m):
                G[j]=res[j]
            butterfly(F)
            butterfly(G)
            # F <- f*g (cyclic, length 2m)
            for j in range(2*m):
                F[j]*=G[j]
                F[j]%=MOD
            butterfly_inv(F)
            # The low m terms of f*g are already 1 (then 0s); keep only the error part.
            for j in range(m):
                F[j]=0
            butterfly(F)
            for j in range(2*m):
                F[j]*=G[j]
                F[j]%=MOD
            butterfly_inv(F)
            # The next m terms of the inverse are -(error * g).
            for j in range(m,min(2*m,deg+1)):
                res[j]=-F[j]
            m<<=1
        return res
    def differentiate(self,deg=-1):
        """
        Return the derivative f' as deg+1 coefficients (default len(self)-1).
        O(N)
        """
        if deg==-1:
            deg=len(self)-2
        res=FormalPowerSeries([0]*(deg+1))
        for i in range(1,min(len(self),deg+2)):
            res[i-1]=self[i]*i
        return res
    def integrate(self,deg=-1):
        """
        Return the integral of f with zero constant term, as deg+1 coefficients
        (default len(self)+1).
        O(N)
        """
        if deg==-1:
            deg=len(self)
        res=FormalPowerSeries([0]*(deg+1))
        for i in range(min(len(self),deg)):
            res[i+1]=self[i]*inv(i+1)
        return res
    def log(self,deg=-1):
        """
        Return log(f) truncated to deg+1 terms, computed as integral(f'/f).

        NOTE(review): only meaningful when self[0] == 1; this is not checked.
        When deg-1 == -1 the inv() call below falls back to its default length.
        O(N log N)
        """
        if deg==-1:
            deg=len(self)-1
        return (self.differentiate()*self.inv(deg-1))[:deg].integrate()
    def exp(self,deg=-1):
        """
        Return exp(f) truncated to deg+1 terms (default len(self) terms).

        Newton iteration doubling the number m of known result terms each
        round, while T tracks the inverse of the current result.
        NOTE(review): presumably requires self[0] == 0; this is not checked.
        """
        if deg==-1:
            deg=len(self)-1
        T=[0]*(deg+1)
        T[0]=1 #T:res^{-1}
        res=FormalPowerSeries(T)
        m=1
        # F holds butterfly(res[:m]).
        F=[1]
        while m<=deg:
            G=T[:m]
            butterfly(G)
            # FG = res*T - 1 (cyclic, length m).
            FG=[F[i]*G[i]%MOD for i in range(m)]
            butterfly_inv(FG)
            FG[0]-=1
            delta=[0]*(2*m)
            for i in range(m):
                delta[i+m]=FG[i]
            eps=[0]*(2*m)
            # DF = res' * T (cyclic, length m); the appended 0 pads res' to length m.
            if m==1:
                DF=[]
            else:
                DF=res.differentiate(m-2).poly
            DF.append(0)
            butterfly(DF)
            for i in range(m):
                DF[i]*=G[i]
                DF[i]%=MOD
            butterfly_inv(DF)
            for i in range(m-1):
                eps[i]=self[i+1]*(i+1)%MOD
                eps[i+m]=DF[i]-eps[i]
            eps[m-1]=DF[m-1]
            butterfly(delta)
            DH=[0]*(2*m)
            for i in range(m-1):
                DH[i]=eps[i]
            butterfly(DH)
            for i in range(2*m):
                delta[i]*=DH[i]
                delta[i]%=MOD
            butterfly_inv(delta)
            for i in range(m,2*m):
                eps[i]-=delta[i]
                eps[i]%=MOD
            # Integrate eps in place (downward so eps[i-1] is still unmodified),
            # then subtract f.
            for i in range(2*m-1,0,-1):
                eps[i]=(eps[i-1]*inv(i)-self[i])%MOD
            eps[0]=-self[0]
            butterfly(eps)
            for i in range(m):
                DH[i]=res[i]
                DH[i+m]=0
            butterfly(DH)
            for i in range(2*m):
                eps[i]*=DH[i]
                eps[i]%=MOD
            butterfly_inv(eps)
            for i in range(m,min(2*m,deg+1)):
                res[i]=-eps[i]
            if 2*m>deg:
                break
            # Extend T (the inverse of res) to 2m terms by one Newton step.
            F=[0]*(2*m)
            G=[0]*(2*m)
            for i in range(2*m):
                F[i]=res[i]
            for i in range(m):
                G[i]=T[i]
            butterfly(F)
            butterfly(G)
            P=[F[i]*G[i]%MOD for i in range(2*m)]
            butterfly_inv(P)
            for i in range(m):
                P[i]=0
            butterfly(P)
            for i in range(2*m):
                P[i]*=G[i]
                P[i]%=MOD
            butterfly_inv(P)
            for i in range(m,2*m):
                T[i]=-P[i]
            m<<=1
        return res
    def sqrt(self,deg=-1):
        """
        Return a square root of f truncated to deg+1 terms, or -1 if none exists.

        If f starts with x^d, d must be even; the code recurses on f/x^d and
        shifts the result by d/2. Otherwise it starts from mod_sqrt(f[0]) and
        runs a Newton iteration doubling m, keeping T = (current root)^{-1}.
        An all-zero (or empty) f gives an empty series.
        """
        if deg==-1:
            deg=len(self)-1
        if len(self)==0:
            return FormalPowerSeries()
        if self[0]==0:
            for d in range(1,len(self)):
                if self[d]:
                    if d&1:
                        return -1
                    if deg<d//2:
                        break
                    res=self[d:].sqrt(deg-d//2)
                    if res==-1:
                        return -1
                    res=res<<(d//2)
                    return res
            return FormalPowerSeries()
        sqr=mod_sqrt(self[0])
        if sqr==-1:
            return -1
        T=[0]*(deg+1)
        T[0]=sqr
        res=FormalPowerSeries(T)
        T[0]=mod_inv(sqr) #T:res^{-1}
        m=1
        two_inv=(MOD+1)//2
        # F holds butterfly(res[:m]); a length-1 butterfly is the identity.
        F=[sqr]
        while m<=deg:
            # res^2 cyclic mod (x^m - 1): its wrap-around folds the high part onto the low part.
            for i in range(m):
                F[i]*=F[i]
                F[i]%=MOD
            butterfly_inv(F)
            delta=[0]*(2*m)
            for i in range(m):
                delta[i+m]=F[i]-self[i]-self[i+m]
            butterfly(delta)
            G=[0]*(2*m)
            for i in range(m):
                G[i]=T[i]
            butterfly(G)
            for i in range(2*m):
                delta[i]*=G[i]
                delta[i]%=MOD
            butterfly_inv(delta)
            # Newton step: new terms = -(res^2 - f) / (2 res).
            for i in range(m,min(2*m,deg+1)):
                res[i]=-delta[i]*two_inv
            if 2*m>deg:
                break
            # Extend T (the inverse of res) to 2m terms.
            F=res.poly[:2*m]
            butterfly(F)
            eps=[F[i]*G[i]%MOD for i in range(2*m)]
            butterfly_inv(eps)
            for i in range(m):
                eps[i]=0
            butterfly(eps)
            for i in range(2*m):
                eps[i]*=G[i]
                eps[i]%=MOD
            butterfly_inv(eps)
            for i in range(m,2*m):
                T[i]=-eps[i]
            m<<=1
        return res
    def multipoint_evaluation(self,P):
        """
        return [f(P_0), f(P_1), f(P_2),...]

        Subproduct tree: leaves are (x - P_i), internal nodes are products of
        their children. f is reduced modulo each node going down; each leaf
        remainder's constant term is f(P_i).
        """
        m=len(P)
        size=1<<(m-1).bit_length()
        G=[FormalPowerSeries([1]) for _ in range(2*size)]
        for i in range(m):
            G[size+i]=FormalPowerSeries([-P[i],1])
        for i in range(size-1,0,-1):
            G[i]=G[2*i]*G[2*i+1]
        G[1]=self%G[1]
        for i in range(2,size+m):
            G[i]=G[i>>1]%G[i]
        res=[G[i][0] for i in range(size,size+m)]
        return res
    def taylor_shift(self,a):
        """
        Return f(z+a), with the same number of coefficients as f.

        Uses one convolution of (f_i * i!) with (a^i / i!) reversed, then
        divides coefficient i by i!.
        """
        a%=MOD
        n=len(self)
        t=1
        F=self[:]
        G=FormalPowerSeries([0]*n)
        for i in range(n):
            F[i]*=fac(i)
        for i in range(n):
            G[i]=t*finv(i)
            t=t*a%MOD
        res=(F*G[::-1])[n-1:]
        for i in range(n):
            res[i]*=finv(i)
        return res
    def composition(self,g,deg=-1):
        """
        Return f(g(z)) truncated to deg+1 terms (default len(self) terms).

        Splits f's coefficients into k blocks of d (k about sqrt(deg)), using
        the precomputed powers g^0..g^k. Block i is summed from g^0..g^(d-1)
        and then multiplied by (g^d)^i.
        Stated complexity: O( (N log N)^3/2 ).
        NOTE(review): the Y accumulation loop below does about k*d*(deg+1),
        i.e. O(deg^2), scalar operations, which exceeds the stated bound.
        """
        if deg==-1:
            deg=len(self)-1
        k=int(deg**0.5+1)
        d=(deg+k)//k
        # X[i] = g^i truncated to deg+1 terms, for i = 0..k.
        X=[FormalPowerSeries([1])]
        for i in range(k):
            X.append((X[i]*g)[:deg+1])
        # Y[i] = sum_j f[i*d+j] * g^j
        Y=[FormalPowerSeries([0]*(deg+1)) for _ in range(k)]
        for i in range(k):
            for j in range(d):
                if i*d+j>deg:
                    break
                for t in range(deg+1):
                    if t>=len(X[j]):
                        break
                    Y[i][t]+=X[j][t]*self[i*d+j]
        res=FormalPowerSeries([0]*(deg+1))
        # Z = (g^d)^i
        Z=FormalPowerSeries([1])
        for i in range(k):
            Y[i]=(Y[i]*Z)[:deg+1]
            for j in range(len(Y[i])):
                res[j]+=Y[i][j]
            Z=(Z*X[d])[:deg+1]
        return res
def product_all(polys):
    """
    Return the product of all factors in `polys`.

    An empty input returns FormalPowerSeries([1]). Factors are multiplied
    pairwise in FIFO order (a balanced product tree), so operand sizes stay
    balanced. That is far cheaper with NTT multiplication than a
    left-to-right fold.
    """
    # Imported locally so this helper does not depend on the script section's
    # module-level `from collections import deque` further down the file.
    from collections import deque
    if not polys:
        return FormalPowerSeries([1])
    queue=deque(polys)
    while len(queue)>1:
        first=queue.popleft()
        second=queue.popleft()
        queue.append(first*second)
    return queue[0]
def poly_coef(Q,P,n):
    """
    Return [x^n] P(x)/Q(x) modulo MOD (Bostan-Mori algorithm).

    Q, P: lists or FormalPowerSeries of coefficients, constant term first.
    Requires deg(Q) > deg(P) and Q[0] != 0 (mod MOD). The inputs are not
    modified. Cost: O(K log K log n), where K = len(Q).
    """
    if type(P)==FormalPowerSeries:
        P=P.poly
    if type(Q)==FormalPowerSeries:
        Q=Q.poly
    m=1<<(len(Q)-1).bit_length()
    # Concatenation builds new lists, so the caller's lists are left untouched.
    P=P+[0]*(2*m-len(P))
    Q=Q+[0]*(2*m-len(Q))
    while n:
        # R(x) = Q(-x). Form P*R and Q*R with cyclic NTTs of length 2m
        # (large enough that no wrap-around occurs).
        R=[0]*(2*m)
        for i,q in enumerate(Q):
            R[i]=(1-2*(i&1))*q
        butterfly(P)
        butterfly(Q)
        butterfly(R)
        for i,r in enumerate(R):
            P[i]*=r
            P[i]%=MOD
            Q[i]*=r
            Q[i]%=MOD
        butterfly_inv(P)
        butterfly_inv(Q)
        # Q(x)Q(-x) is even. Keep the half of P(x)Q(-x) whose parity matches n,
        # then substitute x^2 -> x.
        if n&1:
            for i in range(m):
                P[i]=P[2*i+1]
        else:
            for i in range(m):
                P[i]=P[2*i]
        for i in range(m):
            Q[i]=Q[2*i]
        for i in range(m,2*m):
            P[i]=0
            Q[i]=0
        n>>=1
    # [x^0] P/Q = P[0] / Q[0]. Q[0] need not be 1, so divide explicitly.
    return P[0]*mod_inv(Q[0])%MOD
def subset_sum(A,limit):
    """
    Count the subsets of A by their sum.

    Returns a list res of length limit+1, where res[k] is the number of
    subsets (by position, so equal values are distinguished) of A whose
    elements sum to k, modulo MOD. Elements must be non-negative integers.
    Elements larger than `limit` can never appear in such a subset and are
    skipped.

    Method: log(prod_a (1 + x^a)) = sum_a sum_{k>=1} (-1)^(k+1) x^(a*k) / k,
    grouped by value and followed by a power-series exp. Zero-valued elements
    cannot be expressed this way (1 + x^0 = 2). Each one instead doubles every
    count.
    Cost: O(N + limit log limit).

    Raises ValueError if A contains a negative element.
    """
    if limit<0:
        return []
    C=[0]*(limit+1)
    for a in A:
        if a<0:
            raise ValueError("subset_sum: elements must be non-negative")
        if a<=limit:
            C[a]+=1
    res=FormalPowerSeries([0]*(limit+1))
    for i in range(1,limit+1):
        if not C[i]:
            continue
        for k in range(1,limit//i+1):
            # Sign is +1 for odd k and -1 for even k.
            res[i*k]+=((k&1)*2-1)*C[i]*inv(k)
    counts=res.exp(limit).poly
    zero_factor=pow(2,C[0],MOD)
    return [c*zero_factor%MOD for c in counts]
def partition_function(n):
    """
    Return [p(0), ..., p(n)] modulo MOD, where p(k) counts the integer partitions of k.

    Builds Euler's product prod_{k>=1} (1 - x^k) up to degree n with the
    pentagonal number theorem, whose nonzero terms are at k(3k-1)/2 and
    k(3k+1)/2 with sign (-1)^k, and then inverts it.
    """
    euler=FormalPowerSeries([0]*(n+1))
    euler[0]=1
    k=1
    while k<=n:
        small=k*(3*k-1)//2
        if small>n:
            break
        sign=-1 if k&1 else 1
        large=k*(3*k+1)//2
        if large<=n:
            euler[large]+=sign
        euler[small]+=sign
        k+=1
    return euler.inv().poly
def bernoulli_number(n):
    """
    Return [B_0, B_1, ..., B_n] modulo MOD.

    Uses x/(e^x - 1) = 1 / sum_{i>=0} x^i/(i+1)!, so B_1 = -1/2. B_i is
    i! times the i-th coefficient of that inverse.
    """
    size=n+1
    denominator=FormalPowerSeries([finv(i+1) for i in range(size)])
    egf=denominator.inv(size-1)
    return [fac(i)*c%MOD for i,c in enumerate(egf.poly)]
def stirling_first(n):
    """
    Return the coefficients of x(x-1)(x-2)...(x-n+1) modulo MOD.

    These are the signed Stirling numbers of the first kind s(n, 0..n).
    The bits of n are processed from the most significant one, keeping
    res = x(x-1)...(x-t+1). Each bit doubles t via res(x)*res(x-t), a Taylor
    shift; a set bit then appends one more factor (x - t).
    """
    bits=[]
    rest=n
    while rest:
        bits.append(rest&1)
        rest>>=1
    res=FormalPowerSeries([1])
    t=0
    for bit in reversed(bits):
        res=res*res.taylor_shift(-t)
        t*=2
        if bit:
            res=res*FormalPowerSeries([-t,1])
            t+=1
    return res.poly
def stirling_second(n):
    """
    Return S(n, 0), ..., S(n, n) modulo MOD (Stirling numbers of the second kind).

    S(n, k) = sum_i (-1)^(k-i) i^n / (i! (k-i)!) is the convolution of
    i^n / i! with (-1)^i / i!.
    """
    powers=FormalPowerSeries([pow(i,n,MOD)*finv(i) for i in range(n+1)])
    signs=FormalPowerSeries([(-1 if i&1 else 1)*finv(i) for i in range(n+1)])
    return (powers*signs)[:n+1].poly
def polynominal_interpolation(X,Y):
    """
    Return the len(X) coefficients of the polynomial f of degree < len(X)
    with f(X[i]) == Y[i] (mod MOD).

    Subproduct-tree Lagrange interpolation. M(x) = prod (x - X[i]); the
    weights are Y[i] / M'(X[i]), and they are combined bottom-up.
    Assumes the X[i] are distinct mod MOD.
    """
    n=len(X)
    size=1<<(n-1).bit_length()
    tree=[FormalPowerSeries([1]) for _ in range(2*size)]
    for i,x in enumerate(X):
        tree[size+i]=FormalPowerSeries([-x,1])
    for node in range(size-1,0,-1):
        tree[node]=tree[2*node]*tree[2*node+1]
    # Push M' down the tree; each leaf's constant term becomes M'(X[i]).
    acc=[0]*(2*size)
    acc[1]=tree[1].differentiate()%tree[1]
    for node in range(2,size+n):
        acc[node]=acc[node>>1]%tree[node]
    for i in range(n):
        acc[size+i]=FormalPowerSeries([Y[i]*mod_inv(acc[size+i][0])])
    # Combine: parent = left * M(right) + right * M(left).
    for node in range(size-1,0,-1):
        acc[node]=acc[2*node]*tree[2*node+1]+acc[2*node+1]*tree[2*node]
    return acc[1][:n]
class Mat2:
    """
    2x2 matrix [[a00, a01], [a10, a11]] of FormalPowerSeries entries, used by
    the polynomial half-GCD routines. Mat2() is the identity matrix.

    Mat2 * Mat2 gives the matrix product. Mat2 * (p0, p1) gives the
    matrix-vector product as a tuple. Every resulting entry has its trailing
    zero coefficients stripped.
    """
    def __init__(self,a00=None,a01=None,a10=None,a11=None):
        # Fresh identity entries per instance. Default-argument objects would be
        # shared by every Mat2() call, and FormalPowerSeries is mutable
        # (shrink(), __setitem__).
        self.a00=FormalPowerSeries([1]) if a00 is None else a00
        self.a01=FormalPowerSeries() if a01 is None else a01
        self.a10=FormalPowerSeries() if a10 is None else a10
        self.a11=FormalPowerSeries([1]) if a11 is None else a11
    def _product_entries(self,other):
        """Return the four entries of self*other (both Mat2), trailing zeros stripped."""
        A00=self.a00*other.a00+self.a01*other.a10
        A01=self.a00*other.a01+self.a01*other.a11
        A10=self.a10*other.a00+self.a11*other.a10
        A11=self.a10*other.a01+self.a11*other.a11
        for entry in (A00,A01,A10,A11):
            entry.shrink()
        return A00,A01,A10,A11
    def __mul__(self,other):
        if isinstance(other,Mat2):
            return Mat2(*self._product_entries(other))
        b0=self.a00*other[0]+self.a01*other[1]
        b1=self.a10*other[0]+self.a11*other[1]
        b0.shrink()
        b1.shrink()
        return (b0,b1)
    def __imul__(self,other):
        # In-place matrix product. Must return self; otherwise `m *= x` rebinds m to None.
        self.a00,self.a01,self.a10,self.a11=self._product_entries(other)
        return self
def _inner_naive_gcd(m,p):
    """
    Perform one Euclidean step on the pair p = (p0, p1).

    Returns (p1, p0 mod p1). The step is also applied to the matrix m in
    place: the new first row is the old second row, and the new second row
    is old_row0 - quo * old_row1.
    """
    first,second=p
    quo=first//second
    rem=first-second*quo
    new_a10=m.a00-m.a10*quo
    new_a11=m.a01-m.a11*quo
    rem.shrink()
    new_a10.shrink()
    new_a11.shrink()
    m.a00,m.a01,m.a10,m.a11=m.a10,m.a11,new_a10,new_a11
    return (second,rem)
def _inner_half_gcd(p):
    """
    Half-GCD: return a Mat2 made of Euclidean steps on p = (p0, p1).

    With k = (len(p0)+1)//2, the steps aim to bring the second component of
    M*p down to at most k coefficients. Only the high coefficients are used:
    the recursion runs on shifted (truncated) copies, and those steps are
    valid for the full pair. _inner_poly_gcd only calls this with
    len(p[0]) >= len(p[1]).
    """
    n,m=len(p[0]),len(p[1])
    k=(n+1)//2
    if m<=k:
        return Mat2()
    m1=_inner_half_gcd((p[0]>>k,p[1]>>k))
    p=m1*p
    if len(p[1])<=k:
        return m1
    p=_inner_naive_gcd(m1,p)
    if len(p[1])<=k:
        return m1
    # Keep the top 2*(l-k)+1 coefficients (shift by 2k - l, where
    # l = deg p0). Then the second recursion finishes the reduction to <= k.
    # Shifting by a constant 2k-1 instead drops too many coefficients, and the
    # algorithm degrades toward quadratic time.
    l=len(p[0])-1
    j=2*k-l
    p=(p[0]>>j,p[1]>>j)
    return _inner_half_gcd(p)*m1
def _inner_poly_gcd(a,b):
    """
    Return a Mat2 M such that M*(a, b) has an empty (zero) second component.

    The first component is then a scalar multiple of gcd(a, b).
    Side effect: a and b are shrunk in place (trailing zeros removed).
    """
    a.shrink()
    b.shrink()
    if len(a)<len(b):
        mat=_inner_poly_gcd(b,a)
        # Swap the columns to account for the swapped argument order.
        mat.a00,mat.a01=mat.a01,mat.a00
        mat.a10,mat.a11=mat.a11,mat.a10
        return mat
    pair=(a,b)
    total=Mat2()
    while True:
        step=_inner_half_gcd(pair)
        pair=step*pair
        if len(pair[1])==0:
            return step*total
        pair=_inner_naive_gcd(step,pair)
        if len(pair[1])==0:
            return step*total
        total=step*total
def poly_gcd(a,b):
    """
    Return the monic greatest common divisor of polynomials a and b.

    Returns an empty FormalPowerSeries when both are zero.
    Side effect: a and b are shrunk in place (via _inner_poly_gcd).
    """
    m=_inner_poly_gcd(a,b)
    g=(m*(a,b))[0]
    if len(g):
        # Normalize to monic. g is already shrunk by Mat2.__mul__, so its last
        # stored coefficient is the (nonzero) leading one. Read it from g.poly:
        # FormalPowerSeries.__getitem__ rejects negative indices. Rebind a local
        # instead of assigning into the tuple returned by Mat2.__mul__.
        g=g*mod_inv(g.poly[-1])
    return g
def poly_inv(f,g):
    """
    Return h with f*h == 1 (mod g) over Z/MOD, or -1 if f and g are not coprime.

    Side effect: f and g are shrunk in place (via _inner_poly_gcd).
    """
    mat=_inner_poly_gcd(f,g)
    gcd=(mat*(f,g))[0]
    if len(gcd)!=1:
        return -1
    # The first row satisfies a00*f + a01*g == gcd, so a00 mod g inverts f
    # up to the constant gcd.
    cofactor=(mat*(FormalPowerSeries([1]),g))[0]%g
    result=cofactor*mod_inv(gcd[0])
    result.shrink()
    return result
def linear_recurrence(A,C,n):
    """
    Return a_n (mod MOD) for the linear recurrence
        a_n = C_0 a_(n-1) + C_1 a_(n-2) + ... + C_(K-1) a_(n-K)
    with initial terms A = [a_0, ..., a_(K-1)] and coefficients
    C = [C_0, ..., C_(K-1)].

    The generating function is P/Q with Q = 1 - C_0 x - ... - C_(K-1) x^K and
    P = (A*Q) mod x^K. Cost: O(K log K log n).
    """
    K=len(A)
    denominator=[0]*(K+1)
    denominator[0]=1
    for idx,coef in enumerate(C):
        denominator[idx+1]=-coef
    numerator=convolution(A,denominator)[:K]
    return poly_coef(denominator,numerator,n)
#######################################################
import sys
input = sys.stdin.readline
from collections import deque
N,Q=map(int, input().split())
A=list(map(int, input().split()))
# F(x) = prod over a of (x + a - 1); each query b prints the x^b coefficient.
F=product_all([FormalPowerSeries([a-1,1]) for a in A])
B=list(map(int, input().split()))
for coefficient in (F[b] for b in B):
    print(coefficient)