結果
| 問題 | No.1321 塗るめた |
| コンテスト | |
| ユーザー | vwxyz |
| 提出日時 | 2023-04-27 01:53:22 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 16,124 bytes |
| コンパイル時間 | 256 ms |
| コンパイル使用メモリ | 82,428 KB |
| 実行使用メモリ | 282,308 KB |
| 最終ジャッジ日時 | 2024-11-16 09:40:31 |
| 合計ジャッジ時間 | 49,481 ms |
| ジャッジサーバーID (参考情報) | judge5 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 35 TLE * 10 |
ソースコード
from collections import deque
import math
import sys
# Fast line reader for stdin, and the prime modulus 998244353 = 119 * 2**23 + 1.
readline = sys.stdin.readline
mod = 998_244_353
class FFT:
    """Number-theoretic transform (NTT) and convolution modulo a prime.

    The table layout (root / rate2 / rate3) and the radix-4 butterflies appear
    to follow the AtCoder Library convolution; NOTE(review): confirm against
    ACL before modifying.

    NOTE: the modulus is stored on the class (``FFT.mod``), so every instance
    shares the modulus of the most recently constructed one.
    """

    def __init__(self, mod=998244353):
        FFT.mod = mod
        self.make_info(mod)

    def make_info(self, mod):
        """Precompute the twiddle-factor tables for ``mod``.

        root[i] is a primitive 2**i-th root of unity, built from
        root[rank2] = g**((mod-1) >> rank2) by repeated squaring; iroot holds
        the inverses. rate2 / rate3 (1-indexed) are the per-block rotation
        updates used by the radix-2 / radix-4 passes. IMAG = root[2] is a
        primitive 4th root of unity. Transform lengths are presumably limited
        to 2**rank2, where 2**rank2 is the largest power of two dividing mod-1.
        """
        g = self.primitive_root_constexpr()
        m = mod - 1
        # rank2 = number of trailing zero bits of mod - 1
        rank2 = (m & -m).bit_length() - 1
        root = [0] * (rank2 + 1)
        iroot = [0] * (rank2 + 1)
        rate2 = [0] * (rank2 + 1)
        irate2 = [0] * (rank2 + 1)
        rate3 = [0] * (rank2)
        irate3 = [0] * (rank2)
        root[rank2] = pow(g, (mod - 1) >> rank2, mod)
        iroot[rank2] = pow(root[rank2], mod - 2, mod)
        for i in range(rank2 - 1, -1, -1):
            root[i] = root[i + 1] * root[i + 1] % mod
            iroot[i] = iroot[i + 1] * iroot[i + 1] % mod
        prod = 1
        iprod = 1
        for i in range(1, rank2):
            rate2[i] = root[i + 1] * prod % mod
            irate2[i] = iroot[i + 1] * iprod % mod
            prod = prod * iroot[i + 1] % mod
            iprod = iprod * root[i + 1] % mod
        prod = 1
        iprod = 1
        for i in range(1, rank2 - 1):
            rate3[i] = root[i + 2] * prod % mod
            irate3[i] = iroot[i + 2] * iprod % mod
            prod = prod * iroot[i + 2] % mod
            iprod = iprod * root[i + 2] % mod
        # rate2[1] == root[2]: a primitive 4th root of unity ("i" in the radix-4 step)
        self.IMAG = rate2[1]
        self.IIMAG = irate2[1]
        self.rate2 = rate2
        self.irate2 = irate2
        self.rate3 = rate3
        self.irate3 = irate3

    def primitive_root_constexpr(self):
        """Return a primitive root of ``FFT.mod``.

        Common NTT primes are looked up directly. Otherwise the method returns
        the smallest g >= 2 with g**((mod-1)/q) != 1 for every prime q dividing
        mod-1. Reads the class attribute FFT.mod, which __init__ sets before
        calling this.
        """
        if FFT.mod == 998244353:
            return 3
        elif FFT.mod == 200003:
            return 2
        elif FFT.mod == 167772161:
            return 3
        elif FFT.mod == 469762049:
            return 3
        elif FFT.mod == 754974721:
            return 11
        # Collect the distinct prime factors of mod - 1 (2 first, then odd ones).
        divs = [0] * 20
        divs[0] = 2
        cnt = 1
        x = (FFT.mod - 1) // 2
        while x % 2 == 0:
            x //= 2
        i = 3
        while i * i <= x:
            if x % i == 0:
                divs[cnt] = i
                cnt += 1
                while x % i == 0:
                    x //= i
            i += 2
        if x > 1:
            divs[cnt] = x
            cnt += 1
        g = 2
        while 1:
            ok = True
            for i in range(cnt):
                if pow(g, (FFT.mod - 1) // divs[i], FFT.mod) == 1:
                    ok = False
                    break
            if ok:
                return g
            g += 1

    def butterfly(self, A):
        """In-place forward NTT of ``A``.

        len(A) is expected to be a power of two (convolve pads to one). The
        method handles two levels per pass with radix-4 butterflies, plus one
        radix-2 pass when the remaining depth is odd. The output is not in
        natural order; it is meant to be consumed by butterfly_inv
        (presumably bit-reversed, as in ACL).
        """
        n = len(A)
        h = (n - 1).bit_length()
        le = 0
        while le < h:
            if h - le == 1:
                # single radix-2 level
                p = 1 << (h - le - 1)
                rot = 1
                for s in range(1 << le):
                    offset = s << (h - le)
                    for i in range(p):
                        l = A[i + offset]
                        r = A[i + offset + p] * rot
                        A[i + offset] = (l + r) % FFT.mod
                        A[i + offset + p] = (l - r) % FFT.mod
                    # (~s & (s+1)).bit_length() == 1 + number of trailing 1-bits of s
                    rot *= self.rate2[(~s & -~s).bit_length()]
                    rot %= FFT.mod
                le += 1
            else:
                # two levels at once (radix-4)
                p = 1 << (h - le - 2)
                rot = 1
                for s in range(1 << le):
                    rot2 = rot * rot % FFT.mod
                    rot3 = rot2 * rot % FFT.mod
                    offset = s << (h - le)
                    for i in range(p):
                        a0 = A[i + offset]
                        a1 = A[i + offset + p] * rot
                        a2 = A[i + offset + p * 2] * rot2
                        a3 = A[i + offset + p * 3] * rot3
                        a1na3imag = (a1 - a3) % FFT.mod * self.IMAG
                        A[i + offset] = (a0 + a2 + a1 + a3) % FFT.mod
                        A[i + offset + p] = (a0 + a2 - a1 - a3) % FFT.mod
                        A[i + offset + p * 2] = (a0 - a2 + a1na3imag) % FFT.mod
                        A[i + offset + p * 3] = (a0 - a2 - a1na3imag) % FFT.mod
                    rot *= self.rate3[(~s & -~s).bit_length()]
                    rot %= FFT.mod
                le += 2

    def butterfly_inv(self, A):
        """In-place inverse of ``butterfly``, without the 1/len(A) scaling.

        convolve applies the 1/len(A) factor after calling this.
        """
        n = len(A)
        h = (n - 1).bit_length()
        le = h
        while le:
            if le == 1:
                # single radix-2 level
                p = 1 << (h - le)
                irot = 1
                for s in range(1 << (le - 1)):
                    offset = s << (h - le + 1)
                    for i in range(p):
                        l = A[i + offset]
                        r = A[i + offset + p]
                        A[i + offset] = (l + r) % FFT.mod
                        A[i + offset + p] = (l - r) * irot % FFT.mod
                    irot *= self.irate2[(~s & -~s).bit_length()]
                    irot %= FFT.mod
                le -= 1
            else:
                # two levels at once (radix-4)
                p = 1 << (h - le)
                irot = 1
                for s in range(1 << (le - 2)):
                    irot2 = irot * irot % FFT.mod
                    irot3 = irot2 * irot % FFT.mod
                    offset = s << (h - le + 2)
                    for i in range(p):
                        a0 = A[i + offset]
                        a1 = A[i + offset + p]
                        a2 = A[i + offset + p * 2]
                        a3 = A[i + offset + p * 3]
                        a2na3iimag = (a2 - a3) * self.IIMAG % FFT.mod
                        A[i + offset] = (a0 + a1 + a2 + a3) % FFT.mod
                        A[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % FFT.mod
                        A[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2 % FFT.mod
                        A[i + offset + p * 3] = (a0 - a1 - a2na3iimag) * irot3 % FFT.mod
                    irot *= self.irate3[(~s & -~s).bit_length()]
                    irot %= FFT.mod
                le -= 2

    def convolve(self, A, B):
        """Return the product of ``A`` and ``B`` modulo FFT.mod.

        The result has length len(A) + len(B) - 1. When the shorter input has
        at most 60 terms, schoolbook multiplication is used. Otherwise the
        method runs an NTT on zero-padded copies, so the inputs are never
        modified.

        NOTE(review): if exactly one input is empty, the result is a list of
        zeros of length len(other) - 1 rather than [].
        """
        n = len(A)
        m = len(B)
        if min(n, m) <= 60:
            C = [0] * (n + m - 1)
            for i in range(n):
                # Reduce on every 8th row so each C entry accumulates at most ~8
                # products (< 8 * mod**2 < 2**63); presumably this keeps the ints
                # machine-sized under PyPy.
                if i % 8 == 0:
                    for j in range(m):
                        C[i + j] += A[i] * B[j]
                        C[i + j] %= FFT.mod
                else:
                    for j in range(m):
                        C[i + j] += A[i] * B[j]
            return [c % FFT.mod for c in C]
        A = A[:]
        B = B[:]
        # transform length: smallest power of two >= n + m - 1
        z = 1 << (n + m - 2).bit_length()
        A += [0] * (z - n)
        B += [0] * (z - m)
        self.butterfly(A)
        self.butterfly(B)
        for i in range(z):
            A[i] *= B[i]
            A[i] %= FFT.mod
        self.butterfly_inv(A)
        A = A[:n + m - 1]
        # undo the missing 1/z factor of butterfly_inv
        iz = pow(z, FFT.mod - 2, FFT.mod)
        return [a * iz % FFT.mod for a in A]
class FPS:
    """Formal power series with coefficients modulo a prime (default 998244353).

    Coefficients are stored low-degree first in ``self.f``. Series products go
    through the shared ``FPS.fft`` (an ``FFT`` instance created by the first
    FPS ever constructed).

    NOTE: ``FPS.mod`` is fixed by the first instance; later ``mod`` arguments
    are ignored.
    """

    fact = [1]     # shared factorial table, grown by extend_fact
    invfact = [1]  # shared inverse-factorial table, grown by extend_fact
    mod = None     # set once, by the first FPS constructed

    def __init__(self, data, mod=998244353):
        if FPS.mod is None:
            FPS.mod = mod
            FPS.fft = FFT(mod)
        if isinstance(data, int):
            self.f = [data]
        else:
            self.f = data[:]

    def __len__(self):
        return len(self.f)

    def __getitem__(self, i):
        return self.f[i]

    def __add__(self, other):
        """Return self + other as a new series; neither operand is modified."""
        res = FPS(self.f)
        res += other
        return res

    def __iadd__(self, other):
        """Add ``other`` into self in place (zero-padding self) and return self."""
        if len(self.f) < len(other):
            self.f += [0] * (len(other) - len(self.f))
        for i in range(len(other)):
            self.f[i] += other[i]
            if self.f[i] >= FPS.mod:
                self.f[i] -= FPS.mod
        return self

    def __sub__(self, other):
        """Return self - other as a new series; neither operand is modified."""
        res = FPS(self.f)
        res -= other
        return res

    def __isub__(self, other):
        """Subtract ``other`` from self in place (zero-padding self) and return self."""
        if len(self.f) < len(other):
            self.f += [0] * (len(other) - len(self.f))
        for i in range(len(other)):
            self.f[i] -= other[i]
            if self.f[i] < 0:
                self.f[i] += FPS.mod
        return self

    def __mul__(self, other):
        """Return self * other (other may be an int scalar or an FPS) as a new series."""
        if isinstance(other, int):
            return FPS([other * x % FPS.mod for x in self.f])
        # FFT.convolve never mutates its arguments, so no defensive copies are needed.
        return FPS(FPS.fft.convolve(self.f, other.f))

    def __imul__(self, other):
        """Multiply self in place by an int scalar or an FPS and return self."""
        if isinstance(other, int):
            self.f = [other * x % FPS.mod for x in self.f]
            return self
        self.f = FPS.fft.convolve(self.f, other.f)
        return self

    def inv(self, deg=None):
        """Return 1/self modulo x**deg (deg defaults to len(self)).

        Uses the Newton iteration g <- 2g - f*g**2, doubling precision each step.
        Requires self[0] != 0 (mod p); otherwise the result is meaningless.
        """
        if deg is None:
            deg = len(self)
        g = FPS(pow(self[0], FPS.mod - 2, FPS.mod))
        l = 1
        while l < deg:
            tmp = g * 2
            l *= 2
            gg = g * g
            del gg.f[l:]  # only g**2 mod x**l is needed
            tmp2 = FPS(self.f[:l]) * gg
            g = tmp - tmp2
            del g.f[l:]
        del g.f[deg:]
        return g

    def differential(self):
        """Return the formal derivative of self."""
        return FPS([x * i % FPS.mod for i, x in enumerate(self.f[1:], 1)])

    def extend_fact(self, l):
        """Grow the shared fact / invfact tables so that indices 0..l are valid."""
        l1 = len(FPS.fact)
        l += 1
        if l1 <= l:
            FPS.fact += [0] * (l - l1)
            FPS.invfact += [0] * (l - l1)
            for i in range(l1, l):
                FPS.fact[i] = FPS.fact[i - 1] * i % FPS.mod
            FPS.invfact[l - 1] = pow(FPS.fact[l - 1], FPS.mod - 2, FPS.mod)
            for i in range(l - 1, l1, -1):
                FPS.invfact[i - 1] = FPS.invfact[i] * i % FPS.mod

    def integral(self):
        """Return the formal antiderivative of self with constant term 0."""
        self.extend_fact(len(self))
        # fact[i] * invfact[i + 1] == 1 / (i + 1)
        return FPS([0] + [x * (FPS.fact[i] * FPS.invfact[i + 1] % FPS.mod) % FPS.mod for i, x in enumerate(self.f)])

    def log(self, deg=None):
        """Return log(self) modulo x**deg; requires self[0] == 1."""
        if deg is None:
            deg = len(self)
        tmp = self.differential() * self.inv(deg=deg)
        del tmp.f[deg:]
        tmp = tmp.integral()
        del tmp.f[deg:]
        return tmp

    def exp(self, deg=None):
        """Return exp(self) modulo x**deg; requires self[0] == 0.

        Newton iteration g <- g * (1 - log g + f): each pass doubles the number
        of correct terms, so the loop stops as soon as l >= deg. One more
        doubling would only redo the most expensive step.
        """
        if deg is None:
            deg = len(self)
        g = FPS(1)
        l = 1
        while l < deg:
            l *= 2
            log = FPS(1) - g.log(deg=l) + FPS(self.f[:l])
            del log.f[l:]
            g *= log
            del g.f[l:]
        del g.f[deg:]
        return g

    def __pow__(self, k, deg=None):
        """Return self**k modulo x**deg (deg defaults to len(self)).

        Leading zeros are handled by factoring self = a * x**p * (1 + ...) and
        computing the unit part through log/exp.
        NOTE(review): the builtin three-argument pow(x, k, m) would pass m as deg.
        """
        if deg is None:
            deg = len(self)
        if k == 0:
            ret = [0] * deg
            if ret:
                ret[0] = 1
            return FPS(ret)
        # Locate the lowest-order nonzero coefficient a = self[p] (within deg terms).
        p = None
        for i in range(min(deg, len(self))):
            if self[i] != 0:
                a = self[i]
                p = i
                break
        if p is None or p * k >= deg:
            return FPS([0] * deg)
        inv = pow(a, FPS.mod - 2, FPS.mod)
        tmp = FPS([x * inv % FPS.mod for x in self.f[p:]])
        tmp = tmp.log(deg=deg)
        del tmp.f[deg:]
        tmp *= k
        tmp = tmp.exp(deg=deg)
        tmp = [0] * (p * k) + tmp.f[:deg - p * k]
        times = pow(a, k, FPS.mod)
        return FPS([x * times % FPS.mod for x in tmp])

    def __ipow__(self, k):
        return self.__pow__(k)

    def cipolla(self, a):
        """Return a square root of a modulo FPS.mod (Cipolla), or -1 if none exists."""
        if FPS.mod == 2:
            return a
        elif a == 0:
            return 0
        elif pow(a, (FPS.mod - 1) // 2, FPS.mod) != 1:
            return -1
        # Find b such that b*b - a is a non-residue (or zero).
        b = 0
        while pow((b * b + FPS.mod - a) % FPS.mod, (FPS.mod - 1) // 2, FPS.mod) == 1:
            b += 1
        base = b * b + FPS.mod - a

        # Arithmetic in F_p[w] / (w**2 - base); pairs (a0, b0) represent a0 + b0*w.
        def multi(a0, b0, a1, b1):
            return (a0 * a1 + (b0 * b1 % FPS.mod) * base) % FPS.mod, (a0 * b1 + b0 * a1) % FPS.mod

        def pow_(a, b, n):
            if n == 0:
                return 1, 0
            a_, b_ = pow_(*multi(a, b, a, b), n // 2)
            if n % 2 == 1:
                a_, b_ = multi(a_, b_, a, b)
            return a_, b_

        return pow_(b, 1, (FPS.mod + 1) // 2)[0]

    def sqrt(self, deg=None):
        """Return a square root of self modulo x**deg, or FPS([]) if none exists."""
        if deg is None:
            deg = len(self)
        if len(self) == 0:
            return FPS([0] * deg)
        if self[0] == 0:
            for i in range(1, len(self)):
                if self[i] != 0:
                    if i & 1:
                        return FPS([])
                    if deg <= i // 2:
                        break
                    ret = FPS(self.f[i:]).sqrt(deg - i // 2)
                    if len(ret) == 0:
                        return FPS([])
                    ret.f = [0] * (i // 2) + ret.f
                    if len(ret) < deg:
                        ret.f += [0] * (deg - len(ret))
                    return ret
            return FPS([0] * deg)
        sq = self.cipolla(self[0])
        if sq == -1:
            return FPS([])
        inv2 = (FPS.mod + 1) // 2
        g = FPS([sq])
        l = 1
        # Newton iteration g <- (g + f/g) / 2, doubling precision each pass.
        while l < deg:
            l *= 2
            tmp = FPS(self.f[:l]) * g.inv(deg=l)
            g += tmp
            g *= inv2
            del g.f[l:]  # terms beyond x**l are not yet correct; drop them
        del g.f[deg:]
        return g

    def taylorshift(self, a):
        """Return the series self(x + a), with the same length as self."""
        deg = len(self)
        if deg == 0:
            return FPS([])
        f = self.f[:]
        self.extend_fact(deg)
        for i in range(deg):
            f[i] *= FPS.fact[i]
            f[i] %= FPS.mod
        f = f[::-1]
        # g[i] = a**i / i!
        g = [0] * deg
        g[0] = 1
        for i in range(1, deg):
            g[i] = (g[i - 1] * a % FPS.mod) * (FPS.fact[i - 1] * FPS.invfact[i] % FPS.mod) % FPS.mod
        f = FPS.fft.convolve(f, g)
        del f[deg:]
        f = f[::-1]
        for i in range(deg):
            f[i] *= FPS.invfact[i]
            f[i] %= FPS.mod
        return FPS(f)
def Extended_Euclid(n, m):
    """Return (x, y) with n*x + m*y == gcd(n, m) (non-negative gcd).

    The Euclidean quotients are recorded on the way down and replayed in
    reverse to rebuild the Bezout coefficients.
    """
    quotients = []
    a, b = n, m
    while b:
        quotients.append(a // b)
        a, b = b, a % b
    # a is now +/- gcd; start with the coefficient that makes it non-negative
    x = 1 if a >= 0 else -1
    y = 0
    for q in reversed(quotients):
        x, y = y, x - q * y
    return x, y
class MOD:
    """Powers, factorials and binomials modulo a prime p, or a prime power p**e.

    With ``e`` given, factors of p are stripped from the stored factorials and
    counted in ``self.cnt``, so the stripped factorials stay invertible.
    """

    def __init__(self, p, e=None):
        self.p = p
        self.e = e
        if self.e is None:
            self.mod = self.p
        else:
            self.mod = self.p ** self.e

    def Pow(self, a, n):
        """Return a**n mod self.mod; for n < 0, a must be coprime to the modulus."""
        a %= self.mod
        if n >= 0:
            return pow(a, n, self.mod)
        assert math.gcd(a, self.mod) == 1
        # modular inverse of a via the extended Euclidean algorithm
        x = Extended_Euclid(a, self.mod)[0]
        return pow(x, -n, self.mod)

    def Build_Fact(self, N):
        """Precompute (p-stripped) factorials and their inverses for 0..N.

        NOTE(review): with e is None this presumably assumes N < p; otherwise
        N! == 0 and the inverse computation fails.
        """
        assert N >= 0
        self.factorial = [1]
        if self.e is None:
            for i in range(1, N + 1):
                self.factorial.append(self.factorial[-1] * i % self.mod)
        else:
            # cnt[i] = exponent of p in i!; factorial[i] = i! with all p's removed
            self.cnt = [0] * (N + 1)
            for i in range(1, N + 1):
                self.cnt[i] = self.cnt[i - 1]
                ii = i
                while ii % self.p == 0:
                    ii //= self.p
                    self.cnt[i] += 1
                self.factorial.append(self.factorial[-1] * ii % self.mod)
        self.factorial_inve = [None] * (N + 1)
        self.factorial_inve[-1] = self.Pow(self.factorial[-1], -1)
        for i in range(N - 1, -1, -1):
            ii = i + 1
            while ii % self.p == 0:
                ii //= self.p
            self.factorial_inve[i] = (self.factorial_inve[i + 1] * ii) % self.mod

    def Fact(self, N):
        """Return N! mod self.mod (0 for N < 0)."""
        if N < 0:
            return 0
        retu = self.factorial[N]
        if self.e is not None and self.cnt[N]:
            retu *= pow(self.p, self.cnt[N], self.mod) % self.mod
            retu %= self.mod
        return retu

    def Fact_Inve(self, N):
        """Return 1/N! mod self.mod.

        Returns 0 for N < 0, matching Fact; without this guard, Python's
        negative indexing would silently return the inverse of the largest
        factorial. Returns None when N! is not invertible modulo p**e.
        """
        if N < 0:
            return 0
        if self.e is not None and self.cnt[N]:
            return None
        return self.factorial_inve[N]

    def Comb(self, N, K, divisible_count=False):
        """Return C(N, K) mod self.mod (0 when K < 0 or K > N).

        For a prime-power modulus with divisible_count=True, returns
        (unit part, exponent of p) instead.
        """
        if K < 0 or K > N:
            return 0
        retu = self.factorial[N] * self.factorial_inve[K] % self.mod * self.factorial_inve[N - K] % self.mod
        if self.e is not None:
            cnt = self.cnt[N] - self.cnt[N - K] - self.cnt[K]
            if divisible_count:
                return retu, cnt
            else:
                retu *= pow(self.p, cnt, self.mod)
                retu %= self.mod
        return retu
def solve(N, M, K, mod=998244353):
    """Return C(M, K) * N! * [x^N] (e^x - 1)^K * e^(M x)  modulo ``mod``.

    The previous solution computed the same quantity as
        C(M, K) * N! * sum_{n=K}^{N} [x^n](e^x - 1)^K * M^(N-n) / (N-n)!
    using pure-Python NTT log/exp, which was too slow (TLE). It also crashed
    for K == 0, K > N or M > N. Expanding (e^x - 1)^K binomially and
    multiplying by e^(M x) gives
        N! [x^N] ... = sum_{j=0}^{K} (-1)^(K-j) * C(K, j) * (M + j)^N,
    so the answer is C(M, K) times that sum. The cost is O(K log N), with no
    polynomial arithmetic.

    ``mod`` must be prime and larger than K (inverses are taken via Fermat).
    """
    if K < 0 or K > M:
        return 0
    # factorials and inverse factorials up to K
    fact = [1] * (K + 1)
    for i in range(1, K + 1):
        fact[i] = fact[i - 1] * i % mod
    inv_fact = [1] * (K + 1)
    inv_fact[K] = pow(fact[K], mod - 2, mod)
    for i in range(K, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % mod
    # C(M, K) = M (M-1) ... (M-K+1) / K!  -- valid for M of any size
    falling = 1
    for i in range(K):
        falling = falling * ((M - i) % mod) % mod
    comb_mk = falling * inv_fact[K] % mod
    total = 0
    for j in range(K + 1):
        term = fact[K] * inv_fact[j] % mod * inv_fact[K - j] % mod * pow(M + j, N, mod) % mod
        if (K - j) & 1:
            total -= term
        else:
            total += term
    return comb_mk * (total % mod) % mod


if __name__ == "__main__":
    N, M, K = map(int, sys.stdin.readline().split())
    print(solve(N, M, K))
vwxyz