結果

問題 No.1321 塗るめた
ユーザー vwxyzvwxyz
提出日時 2023-04-27 01:42:51
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 32,080 bytes
コンパイル時間 181 ms
コンパイル使用メモリ 82,172 KB
実行使用メモリ 271,076 KB
最終ジャッジ日時 2024-11-16 09:33:22
合計ジャッジ時間 64,525 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 51 ms
60,748 KB
testcase_01 AC 249 ms
78,996 KB
testcase_02 AC 52 ms
60,800 KB
testcase_03 AC 50 ms
59,520 KB
testcase_04 AC 50 ms
60,416 KB
testcase_05 AC 55 ms
59,648 KB
testcase_06 AC 50 ms
60,160 KB
testcase_07 AC 66 ms
70,144 KB
testcase_08 AC 90 ms
76,172 KB
testcase_09 AC 88 ms
76,288 KB
testcase_10 AC 55 ms
60,288 KB
testcase_11 AC 113 ms
77,204 KB
testcase_12 AC 1,251 ms
163,232 KB
testcase_13 TLE -
testcase_14 AC 1,167 ms
158,132 KB
testcase_15 AC 1,221 ms
162,216 KB
testcase_16 AC 1,252 ms
164,144 KB
testcase_17 AC 245 ms
78,860 KB
testcase_18 TLE -
testcase_19 AC 685 ms
112,620 KB
testcase_20 AC 1,214 ms
159,380 KB
testcase_21 TLE -
testcase_22 TLE -
testcase_23 TLE -
testcase_24 TLE -
testcase_25 TLE -
testcase_26 TLE -
testcase_27 TLE -
testcase_28 TLE -
testcase_29 TLE -
testcase_30 TLE -
testcase_31 TLE -
testcase_32 AC 1,244 ms
163,656 KB
testcase_33 AC 1,228 ms
160,168 KB
testcase_34 AC 1,263 ms
164,100 KB
testcase_35 AC 1,228 ms
159,920 KB
testcase_36 TLE -
testcase_37 TLE -
testcase_38 TLE -
testcase_39 TLE -
testcase_40 TLE -
testcase_41 TLE -
testcase_42 AC 49 ms
59,904 KB
testcase_43 AC 1,265 ms
163,272 KB
testcase_44 AC 1,225 ms
159,784 KB
testcase_45 AC 1,261 ms
164,228 KB
testcase_46 AC 437 ms
91,740 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import math
import sys
readline=sys.stdin.readline

mod=998244353
def NTT(polynomial0,polynomial1):
    """Multiply two coefficient lists modulo the module-level ``mod`` via an NTT.

    Coefficients are stored lowest degree first.  The inputs are not
    modified; a new list with ``len(polynomial0)+len(polynomial1)-1``
    entries, each reduced into ``[0, mod)``, is returned.  If either input is
    empty the product is the zero polynomial and ``[]`` is returned.

    For a modulus other than 998244353 the primitive root is obtained through
    ``Primitive_Root`` and ``MOD``, which must be available at module level.
    """
    if not polynomial0 or not polynomial1:
        # Without this guard two empty inputs produced a list of zeros.
        return []
    if mod==998244353:
        prim_root=3
        prim_root_inve=332748118
    else:
        prim_root=Primitive_Root(mod)
        prim_root_inve=MOD(mod).Pow(prim_root,-1)
    def DFT(polynomial,n,inverse=False):
        # In-place transform of a list of length 2**n.  The forward pass is
        # decimation-in-frequency and the inverse decimation-in-time, so the
        # two cancel without an explicit bit-reversal permutation.
        size=1<<n
        if inverse:
            for bit in range(1,n+1):
                a=1<<bit-1
                x=pow(prim_root,mod-1>>bit,mod)
                # U[j] = x**j for the a butterflies of one block.
                U=[1]*a
                for j in range(1,a):
                    U[j]=U[j-1]*x%mod
                for base in range(0,size,2*a):
                    for j in range(a):
                        s=base+j
                        t=s+a
                        v=polynomial[t]*U[j]
                        polynomial[s],polynomial[t]=(polynomial[s]+v)%mod,(polynomial[s]-v)%mod
            # Divide by 2**n: (mod+1)//2 is the inverse of 2 for odd mod.
            x=pow((mod+1)//2,n,mod)
            for i in range(size):
                polynomial[i]=polynomial[i]*x%mod
        else:
            for bit in range(n,0,-1):
                a=1<<bit-1
                x=pow(prim_root_inve,mod-1>>bit,mod)
                U=[1]*a
                for j in range(1,a):
                    U[j]=U[j-1]*x%mod
                for base in range(0,size,2*a):
                    for j in range(a):
                        s=base+j
                        t=s+a
                        polynomial[s],polynomial[t]=(polynomial[s]+polynomial[t])%mod,U[j]*(polynomial[s]-polynomial[t])%mod

    l=len(polynomial0)+len(polynomial1)-1
    # Smallest n with 2**n >= l.
    n=(l-1).bit_length()
    polynomial0=polynomial0+[0]*((1<<n)-len(polynomial0))
    polynomial1=polynomial1+[0]*((1<<n)-len(polynomial1))
    DFT(polynomial0,n)
    DFT(polynomial1,n)
    ntt=[x*y%mod for x,y in zip(polynomial0,polynomial1)]
    DFT(ntt,n,inverse=True)
    return ntt[:l]

def NTT_Pow(polynomial,N):
    """Return ``polynomial**N`` modulo the module-level ``mod``.

    The polynomial is transformed once, every transformed value is raised to
    the ``N``-th power, and the result is transformed back.  Coefficients are
    stored lowest degree first and the input list is not modified.

    ``N==0`` yields ``[1]``, ``N==1`` yields a copy of the input, and an
    empty (zero) polynomial raised to a positive power yields ``[]``.
    """
    if N==0:
        return [1]
    if N==1:
        return [x for x in polynomial]
    if not polynomial:
        # Without this guard the size computation below went negative and a
        # list of zeros was returned for the zero polynomial.
        return []
    if mod==998244353:
        prim_root=3
        prim_root_inve=332748118
    else:
        prim_root=Primitive_Root(mod)
        prim_root_inve=MOD(mod).Pow(prim_root,-1)
    def DFT(polynomial,n,inverse=False):
        # In-place transform of a list of length 2**n; forward and inverse
        # use opposite butterfly orders so no bit reversal is needed.
        size=1<<n
        if inverse:
            for bit in range(1,n+1):
                a=1<<bit-1
                x=pow(prim_root,mod-1>>bit,mod)
                U=[1]*a
                for j in range(1,a):
                    U[j]=U[j-1]*x%mod
                for base in range(0,size,2*a):
                    for j in range(a):
                        s=base+j
                        t=s+a
                        v=polynomial[t]*U[j]
                        polynomial[s],polynomial[t]=(polynomial[s]+v)%mod,(polynomial[s]-v)%mod
            # Divide by 2**n: (mod+1)//2 is the inverse of 2 for odd mod.
            x=pow((mod+1)//2,n,mod)
            for i in range(size):
                polynomial[i]=polynomial[i]*x%mod
        else:
            for bit in range(n,0,-1):
                a=1<<bit-1
                x=pow(prim_root_inve,mod-1>>bit,mod)
                U=[1]*a
                for j in range(1,a):
                    U[j]=U[j-1]*x%mod
                for base in range(0,size,2*a):
                    for j in range(a):
                        s=base+j
                        t=s+a
                        polynomial[s],polynomial[t]=(polynomial[s]+polynomial[t])%mod,U[j]*(polynomial[s]-polynomial[t])%mod
    # The result has degree (len-1)*N; 2**n is strictly larger than that.
    degree=(len(polynomial)-1)*N
    n=degree.bit_length()
    ntt=polynomial+[0]*((1<<n)-len(polynomial))
    DFT(ntt,n)
    ntt=[pow(v,N,mod) for v in ntt]
    DFT(ntt,n,inverse=True)
    return ntt[:degree+1]

def Extended_Euclid(n,m):
    """Return ``(x, y)`` with ``n*x + m*y == abs(gcd(n, m))``.

    Runs the Euclidean algorithm forwards while tracking how the current
    pair of remainders is expressed in terms of the original ``n`` and ``m``.
    """
    # (cur_x, cur_y) express the current n; (nxt_x, nxt_y) the current m.
    cur_x,nxt_x=1,0
    cur_y,nxt_y=0,1
    while m:
        q,r=divmod(n,m)
        n,m=m,r
        cur_x,nxt_x=nxt_x,cur_x-q*nxt_x
        cur_y,nxt_y=nxt_y,cur_y-q*nxt_y
    if n<0:
        # Flip the signs so the combination equals the positive gcd.
        return -cur_x,-cur_y
    return cur_x,cur_y

class MOD:
    """Modular arithmetic helper for a modulus ``p`` or a prime power ``p**e``.

    The modulus is kept on the instance as ``self.mod``.  ``Build_Fact`` must
    be called before ``Fact``, ``Fact_Inve`` or ``Comb`` are used.  In
    prime-power mode (``e`` given) the factorial tables hold the factorials
    with every factor ``p`` removed and ``self.cnt[i]`` counts the removed
    factors of ``i!``.
    """
    def __init__(self,p,e=None):
        self.p=p
        self.e=e
        # Store the modulus on the instance so the methods below use it
        # instead of whatever module-level ``mod`` happens to exist.
        if self.e is None:
            self.mod=self.p
        else:
            self.mod=self.p**self.e

    def Pow(self,a,n):
        """Return ``a**n`` modulo ``self.mod``; negative ``n`` uses the modular inverse.

        For negative ``n`` the base must be coprime to the modulus (asserted).
        """
        mod=self.mod
        a%=mod
        if n>=0:
            return pow(a,n,mod)
        else:
            assert math.gcd(a,mod)==1
            # The Bezout coefficient of a is its inverse modulo mod.
            x=Extended_Euclid(a,mod)[0]
            return pow(x,-n,mod)

    def Build_Fact(self,N):
        """Precompute factorials and inverse factorials for ``0..N``."""
        assert N>=0
        mod=self.mod
        self.factorial=[1]
        if self.e is None:
            for i in range(1,N+1):
                self.factorial.append(self.factorial[-1]*i%mod)
        else:
            self.cnt=[0]*(N+1)
            for i in range(1,N+1):
                self.cnt[i]=self.cnt[i-1]
                ii=i
                # Strip factors of p so the stored value stays invertible.
                while ii%self.p==0:
                    ii//=self.p
                    self.cnt[i]+=1
                self.factorial.append(self.factorial[-1]*ii%mod)
        self.factorial_inve=[None]*(N+1)
        self.factorial_inve[-1]=self.Pow(self.factorial[-1],-1)
        # Walk downwards: inverse of (i)! is inverse of (i+1)! times (i+1).
        for i in range(N-1,-1,-1):
            ii=i+1
            while ii%self.p==0:
                ii//=self.p
            self.factorial_inve[i]=(self.factorial_inve[i+1]*ii)%mod

    def Fact(self,N):
        """Return ``N!`` modulo ``self.mod`` (0 for negative ``N``)."""
        if N<0:
            return 0
        mod=self.mod
        retu=self.factorial[N]
        if self.e is not None and self.cnt[N]:
            # Put back the factors of p that the table leaves out.
            retu*=pow(self.p,self.cnt[N],mod)%mod
            retu%=mod
        return retu

    def Fact_Inve(self,N):
        """Return the inverse of ``N!``, or ``None`` when ``p`` divides ``N!`` in prime-power mode."""
        if self.e is not None and self.cnt[N]:
            return None
        return self.factorial_inve[N]

    def Comb(self,N,K,divisible_count=False):
        """Return the binomial coefficient ``C(N, K)`` modulo ``self.mod``.

        Returns 0 when ``K`` is outside ``0..N``.  In prime-power mode with
        ``divisible_count=True`` the result is the pair
        ``(p-free part, exponent of p)`` instead of the reduced value.
        """
        if K<0 or K>N:
            return 0
        mod=self.mod
        retu=self.factorial[N]*self.factorial_inve[K]%mod*self.factorial_inve[N-K]%mod
        if self.e is not None:
            cnt=self.cnt[N]-self.cnt[N-K]-self.cnt[K]
            if divisible_count:
                return retu,cnt
            else:
                retu*=pow(self.p,cnt,mod)
                retu%=mod
        return retu

def Tonelli_Shanks(N,p):
    """Return one square root of ``N`` modulo the prime ``p``, or ``None`` if none exists.

    Only one of the two roots is returned; which one is not specified.
    ``p`` is assumed to be prime.
    """
    if p==2:
        # Modulo 2 every residue is its own square root.  The Euler
        # criterion below cannot be applied because p-1 == 1 there.
        return N%2
    if pow(N,p>>1,p)==p-1:
        # Euler's criterion: N is a quadratic non-residue.
        retu=None
    elif p%4==3:
        retu=pow(N,(p+1)//4,p)
    else:
        # Find any quadratic non-residue by linear search.
        for nonresidue in range(1,p):
            if pow(nonresidue,p>>1,p)==p-1:
                break
        # Write p-1 = pp * 2**cnt with pp odd.
        pp=p-1
        cnt=0
        while pp%2==0:
            pp//=2
            cnt+=1
        # Invariant: retu**2 == N*s (mod p); s is driven to 1 step by step.
        s=pow(N,pp,p)
        retu=pow(N,(pp+1)//2,p)
        for i in range(cnt-2,-1,-1):
            if pow(s,1<<i,p)==p-1:
                s*=pow(nonresidue,p>>(1+i),p)
                s%=p
                retu*=pow(nonresidue,p>>(2+i),p)
                retu%=p
    return retu

def FFT(polynomial0,polynomial1,digit=10**5):
    """Multiply two coefficient lists with a floating-point FFT.

    Every coefficient is split as ``high*digit+low`` and the product is
    assembled from three smaller convolutions (high*high, low*low and
    (high+low)*(high+low)) to keep the floating-point values small.
    Coefficients are stored lowest degree first; the inputs are not modified.
    """
    def DFT(polynomial,n,inverse=False):
        # In-place transform of a list of length 2**n (no scaling applied).
        size=1<<n
        sign=-1 if inverse else 1
        primitive_root=[math.cos(sign*i*2*math.pi/size)+math.sin(sign*i*2*math.pi/size)*1j for i in range(size)]
        if inverse:
            for bit in range(1,n+1):
                half=1<<(bit-1)
                shift=n-bit
                for start in range(0,size,2*half):
                    for j in range(half):
                        lo=start+j
                        hi=lo+half
                        u=polynomial[lo]
                        v=polynomial[hi]*primitive_root[j<<shift]
                        polynomial[lo]=u+v
                        polynomial[hi]=u-v
        else:
            for bit in range(n,0,-1):
                half=1<<(bit-1)
                shift=n-bit
                for start in range(0,size,2*half):
                    for j in range(half):
                        lo=start+j
                        hi=lo+half
                        u=polynomial[lo]
                        v=polynomial[hi]
                        polynomial[lo]=u+v
                        polynomial[hi]=primitive_root[j<<shift]*(u-v)

    def FFT_(polynomial0,polynomial1):
        # Plain rounded convolution of two lists.
        len0=len(polynomial0)
        len1=len(polynomial1)
        out_len=len0+len1-1
        n=(out_len-1).bit_length()
        size=1<<n
        padded0=polynomial0+[0]*(size-len0)
        padded1=polynomial1+[0]*(size-len1)
        DFT(padded0,n)
        DFT(padded1,n)
        spectrum=[x*y for x,y in zip(padded0,padded1)]
        DFT(spectrum,n,inverse=True)
        return [round((spectrum[i]/(1<<n)).real) for i in range(out_len)]

    # Split every coefficient into (high, low) parts relative to digit.
    split0=[divmod(c,digit) for c in polynomial0]
    split1=[divmod(c,digit) for c in polynomial1]
    high0=[q for q,_ in split0]
    low0=[r for _,r in split0]
    high1=[q for q,_ in split1]
    low1=[r for _,r in split1]
    high=FFT_(high0,high1)
    low=FFT_(low0,low1)
    mixed=FFT_([h+l for h,l in zip(high0,low0)],[h+l for h,l in zip(high1,low1)])
    # hh*digit**2 + (hl+lh)*digit + ll, with hl+lh = mixed-hh-ll folded in.
    result=[0]*(len(polynomial0)+len(polynomial1)-1)
    high_weight=digit**2-digit
    low_weight=digit-1
    for i,(hh,ll,mx) in enumerate(zip(high,low,mixed)):
        result[i]+=hh*high_weight
        result[i]-=ll*low_weight
        result[i]+=mx*digit
    return result

def FFT_Pow(polynomial,N):
    """Return ``polynomial**N`` using a floating-point FFT, rounded to integers.

    The polynomial is transformed once, every transformed value is raised to
    the ``N``-th power, and the result is transformed back and rounded.
    ``N==0`` yields ``[1]`` and ``N==1`` yields a copy of the input.
    """
    if N==0:
        return [1]
    if N==1:
        return [x for x in polynomial]
    def DFT(polynomial,n,inverse=False):
        # In-place transform of a list of length 2**n (no scaling applied).
        size=1<<n
        sign=-1 if inverse else 1
        primitive_root=[math.cos(sign*i*2*math.pi/size)+math.sin(sign*i*2*math.pi/size)*1j for i in range(size)]
        if inverse:
            for bit in range(1,n+1):
                half=1<<(bit-1)
                shift=n-bit
                for start in range(0,size,2*half):
                    for j in range(half):
                        lo=start+j
                        hi=lo+half
                        u=polynomial[lo]
                        v=polynomial[hi]*primitive_root[j<<shift]
                        polynomial[lo]=u+v
                        polynomial[hi]=u-v
        else:
            for bit in range(n,0,-1):
                half=1<<(bit-1)
                shift=n-bit
                for start in range(0,size,2*half):
                    for j in range(half):
                        lo=start+j
                        hi=lo+half
                        u=polynomial[lo]
                        v=polynomial[hi]
                        polynomial[lo]=u+v
                        polynomial[hi]=primitive_root[j<<shift]*(u-v)

    degree=(len(polynomial)-1)*N
    n=degree.bit_length()
    fft=polynomial+[0]*((1<<n)-len(polynomial))
    DFT(fft,n)
    fft=[pow(value,N) for value in fft]
    DFT(fft,n,inverse=True)
    return [round((fft[i]/(1<<n)).real) for i in range(degree+1)]

class Polynomial:
    """Dense polynomial / truncated power series, coefficients lowest degree first.

    ``max_degree`` (-1 meaning unlimited) caps the stored degree and ``eps``
    is the tolerance under which a coefficient counts as zero.

    NOTE: the ``mod`` constructor argument is accepted but not stored; every
    method consults the module-level ``mod`` (0 meaning plain, non-modular
    arithmetic).  Methods that multiply or invert also rely on the
    module-level ``NTT``, ``NTT_Pow``, ``FFT``, ``FFT_Pow``, ``MOD`` and
    ``Tonelli_Shanks``.
    """
    def __init__(self,polynomial,max_degree=-1,eps=0,mod=0):
        """Wrap ``polynomial`` (not copied unless it has to be truncated)."""
        self.max_degree=max_degree
        if self.max_degree!=-1 and len(polynomial)>self.max_degree+1:
            self.polynomial=polynomial[:self.max_degree+1]
        else:
            self.polynomial=polynomial
        self.eps=eps

    def __eq__(self,other):
        """Equal when lengths match and all coefficients differ by at most ``eps``."""
        if type(other)!=Polynomial:
            return False
        if len(self.polynomial)!=len(other.polynomial):
            return False
        for i in range(len(self.polynomial)):
            if self.eps<abs(self.polynomial[i]-other.polynomial[i]):
                return False
        return True

    def __ne__(self,other):
        """Exact negation of ``__eq__``."""
        return not self.__eq__(other)

    def __add__(self,other):
        """Add a Polynomial, or add a scalar to the constant term."""
        if type(other)==Polynomial:
            summ=[0]*max(len(self.polynomial),len(other.polynomial))
            for i in range(len(self.polynomial)):
                summ[i]+=self.polynomial[i]
            for i in range(len(other.polynomial)):
                summ[i]+=other.polynomial[i]
            if mod:
                for i in range(len(summ)):
                    summ[i]%=mod
        else:
            summ=[x for x in self.polynomial] if self.polynomial else [0]
            summ[0]+=other
            if mod:
                summ[0]%=mod
        # Drop trailing (near-)zero coefficients.
        while summ and abs(summ[-1])<=self.eps:
            summ.pop()
        summ=Polynomial(summ,max_degree=self.max_degree,eps=self.eps,mod=mod)
        return summ

    def __sub__(self,other):
        """Subtract a Polynomial, or subtract a scalar from the constant term."""
        if type(other)==Polynomial:
            diff=[0]*max(len(self.polynomial),len(other.polynomial))
            for i in range(len(self.polynomial)):
                diff[i]+=self.polynomial[i]
            for i in range(len(other.polynomial)):
                diff[i]-=other.polynomial[i]
            if mod:
                for i in range(len(diff)):
                    diff[i]%=mod
        else:
            diff=[x for x in self.polynomial] if self.polynomial else [0]
            diff[0]-=other
            if mod:
                diff[0]%=mod
        while diff and abs(diff[-1])<=self.eps:
            diff.pop()
        diff=Polynomial(diff,max_degree=self.max_degree,eps=self.eps,mod=mod)
        return diff

    def __mul__(self,other):
        """Schoolbook product with a Polynomial, or scaling by a scalar."""
        if type(other)==Polynomial:
            if self.max_degree==-1:
                prod=[0]*(len(self.polynomial)+len(other.polynomial)-1)
                for i in range(len(self.polynomial)):
                    for j in range(len(other.polynomial)):
                        prod[i+j]+=self.polynomial[i]*other.polynomial[j]
            else:
                # Skip terms that would land above max_degree.
                prod=[0]*min(len(self.polynomial)+len(other.polynomial)-1,self.max_degree+1)
                for i in range(len(self.polynomial)):
                    for j in range(min(len(other.polynomial),self.max_degree+1-i)):
                        prod[i+j]+=self.polynomial[i]*other.polynomial[j]
            if mod:
                for i in range(len(prod)):
                    prod[i]%=mod
        else:
            if mod:
                prod=[x*other%mod for x in self.polynomial]
            else:
                prod=[x*other for x in self.polynomial]
        while prod and abs(prod[-1])<=self.eps:
            prod.pop()
        prod=Polynomial(prod,max_degree=self.max_degree,eps=self.eps,mod=mod)
        return prod

    def __matmul__(self,other):
        """Fast product (NTT when ``mod`` is set, FFT otherwise), truncated to ``max_degree``."""
        assert type(other)==Polynomial
        if mod:
            prod=NTT(self.polynomial,other.polynomial)
        else:
            prod=FFT(self.polynomial,other.polynomial)
        if self.max_degree!=-1 and len(prod)>self.max_degree+1:
            prod=prod[:self.max_degree+1]
            # Trailing zeros are only stripped when truncation happened.
            while prod and abs(prod[-1])<=self.eps:
                prod.pop()
        prod=Polynomial(prod,max_degree=self.max_degree,eps=self.eps,mod=mod)
        return prod

    def __pow__(self,other):
        """Raise to a non-negative integer power by binary exponentiation.

        Every intermediate product is truncated to ``max_degree``.
        """
        if other==0:
            prod=Polynomial([1],max_degree=self.max_degree,eps=self.eps,mod=mod)
        elif other==1:
            prod=Polynomial([x for x in self.polynomial],max_degree=self.max_degree,eps=self.eps,mod=mod)
        else:
            prod=[1]
            doub=self.polynomial
            if mod:
                convolve=NTT
                convolve_Pow=NTT_Pow
            else:
                convolve=FFT
                convolve_Pow=FFT_Pow
            while other>=2:
                if other&1:
                    prod=convolve(prod,doub)
                    if self.max_degree!=-1:
                        prod=prod[:self.max_degree+1]
                doub=convolve_Pow(doub,2)
                if self.max_degree!=-1:
                    doub=doub[:self.max_degree+1]
                other>>=1
            prod=convolve(prod,doub)
            if self.max_degree!=-1:
                prod=prod[:self.max_degree+1]
            prod=Polynomial(prod,max_degree=self.max_degree,eps=self.eps,mod=mod)
        return prod

    def __truediv__(self,other):
        """Divide by a Polynomial (as power series, low degree first) or by a scalar.

        For a Polynomial divisor, the leading run of zero coefficients of the
        divisor is cancelled against the dividend (asserted to be zero
        there).  If the division leaves a remainder, ``max_degree`` must be
        set and quotient coefficients are produced up to that degree.
        """
        if type(other)==Polynomial:
            assert other.polynomial
            # n = index of the first non-zero coefficient of the divisor.
            for n in range(len(other.polynomial)):
                if self.eps<abs(other.polynomial[n]):
                    break
            assert len(self.polynomial)>n
            for i in range(n):
                assert abs(self.polynomial[i])<=self.eps
            self_polynomial=self.polynomial[n:]
            other_polynomial=other.polynomial[n:]
            if mod:
                inve=MOD(mod).Pow(other_polynomial[0],-1)
            else:
                inve=1/other_polynomial[0]
            quot=[]
            for i in range(len(self_polynomial)-len(other_polynomial)+1):
                if mod:
                    quot.append(self_polynomial[i]*inve%mod)
                else:
                    quot.append(self_polynomial[i]*inve)
                for j in range(len(other_polynomial)):
                    self_polynomial[i+j]-=other_polynomial[j]*quot[-1]
                    if mod:
                        self_polynomial[i+j]%=mod
            for i in range(max(0,len(self_polynomial)-len(other_polynomial)+1),len(self_polynomial)):
                if self.eps<abs(self_polynomial[i]):
                    # Non-zero remainder: continue as an infinite series.
                    assert self.max_degree!=-1
                    self_polynomial=self_polynomial[-len(other_polynomial)+1:]+[0]*(len(other_polynomial)-1-len(self_polynomial))
                    while len(quot)<=self.max_degree:
                        self_polynomial.append(0)
                        if mod:
                            quot.append(self_polynomial[0]*inve%mod)
                            self_polynomial=[(self_polynomial[i]-other_polynomial[i]*quot[-1])%mod for i in range(1,len(self_polynomial))]
                        else:
                            quot.append(self_polynomial[0]*inve)
                            self_polynomial=[(self_polynomial[i]-other_polynomial[i]*quot[-1]) for i in range(1,len(self_polynomial))]
                    break
            quot=Polynomial(quot,max_degree=self.max_degree,eps=self.eps,mod=mod)
        else:
            assert self.eps<abs(other)
            if mod:
                inve=MOD(mod).Pow(other,-1)
                quot=Polynomial([x*inve%mod for x in self.polynomial],max_degree=self.max_degree,eps=self.eps,mod=mod)
            else:
                quot=Polynomial([x/other for x in self.polynomial],max_degree=self.max_degree,eps=self.eps,mod=mod)
        return quot

    def __rtruediv__(self,other):
        """Return ``other / self`` as a power series up to ``max_degree``.

        Uses Newton doubling for the inverse series; requires a non-zero
        constant term and a finite ``max_degree``.
        """
        # abs() so a negative constant term is accepted in non-modular mode.
        assert self.polynomial and self.eps<abs(self.polynomial[0])
        assert self.max_degree!=-1
        if mod:
            quot=[MOD(mod).Pow(self.polynomial[0],-1)]
            if mod==998244353:
                prim_root=3
                prim_root_inve=332748118
            else:
                prim_root=Primitive_Root(mod)
                prim_root_inve=MOD(mod).Pow(prim_root,-1)
            def DFT(polynomial,n,inverse=False):
                # Works on a zero-padded copy of length 2**n and returns it.
                polynomial=polynomial+[0]*((1<<n)-len(polynomial))
                if inverse:
                    for bit in range(1,n+1):
                        a=1<<bit-1
                        x=pow(prim_root,mod-1>>bit,mod)
                        U=[1]
                        for _ in range(a):
                            U.append(U[-1]*x%mod)
                        for i in range(1<<n-bit):
                            for j in range(a):
                                s=i*2*a+j
                                t=s+a
                                polynomial[s],polynomial[t]=(polynomial[s]+polynomial[t]*U[j])%mod,(polynomial[s]-polynomial[t]*U[j])%mod
                    x=pow((mod+1)//2,n,mod)
                    for i in range(1<<n):
                        polynomial[i]*=x
                        polynomial[i]%=mod
                else:
                    for bit in range(n,0,-1):
                        a=1<<bit-1
                        x=pow(prim_root_inve,mod-1>>bit,mod)
                        U=[1]
                        for _ in range(a):
                            U.append(U[-1]*x%mod)
                        for i in range(1<<n-bit):
                            for j in range(a):
                                s=i*2*a+j
                                t=s+a
                                polynomial[s],polynomial[t]=(polynomial[s]+polynomial[t])%mod,U[j]*(polynomial[s]-polynomial[t])%mod
                return polynomial
        else:
            quot=[1/self.polynomial[0]]
            def DFT(polynomial,n,inverse=False):
                # Works on a zero-padded copy of length 2**n and returns it;
                # the inverse direction also scales and rounds the values.
                N=len(polynomial)
                if inverse:
                    primitive_root=[math.cos(-i*2*math.pi/(1<<n))+math.sin(-i*2*math.pi/(1<<n))*1j for i in range(1<<n)]
                else:
                    primitive_root=[math.cos(i*2*math.pi/(1<<n))+math.sin(i*2*math.pi/(1<<n))*1j for i in range(1<<n)]
                polynomial=polynomial+[0]*((1<<n)-N)
                if inverse:
                    for bit in range(1,n+1):
                        a=1<<bit-1
                        for i in range(1<<n-bit):
                            for j in range(a):
                                s=i*2*a+j
                                t=s+a
                                polynomial[s],polynomial[t]=polynomial[s]+polynomial[t]*primitive_root[j<<n-bit],polynomial[s]-polynomial[t]*primitive_root[j<<n-bit]
                    for i in range(1<<n):
                        polynomial[i]=round((polynomial[i]/(1<<n)).real)
                else:
                    for bit in range(n,0,-1):
                        a=1<<bit-1
                        for i in range(1<<n-bit):
                            for j in range(a):
                                s=i*2*a+j
                                t=s+a
                                polynomial[s],polynomial[t]=polynomial[s]+polynomial[t],primitive_root[j<<n-bit]*(polynomial[s]-polynomial[t])

                return polynomial
        # Each round doubles the number of correct coefficients:
        # g <- 2g - f*g*g, computed on 2**(n+2) points.
        for n in range(self.max_degree.bit_length()):
            prev=quot
            if mod:
                quot=DFT([x*y%mod*y%mod for x,y in zip(DFT(self.polynomial[:1<<n+1],n+2),DFT(prev,n+2))],n+2,inverse=True)[:1<<n+1]
            else:
                quot=DFT([x*y*y for x,y in zip(DFT(self.polynomial[:1<<n+1],n+2),DFT(prev,n+2))],n+2,inverse=True)[:1<<n+1]
            for i in range(1<<n):
                quot[i]=2*prev[i]-quot[i]
                if mod:
                    quot[i]%=mod
            for i in range(1<<n,1<<n+1):
                quot[i]=-quot[i]
                if mod:
                    quot[i]%=mod
        quot=quot[:self.max_degree+1]
        if abs(other-1)>self.eps:
            for i in range(len(quot)):
                quot[i]*=other
                if mod:
                    quot[i]%=mod
        quot=Polynomial(quot,max_degree=self.max_degree,eps=self.eps,mod=mod)
        return quot

    def _long_division(self,other):
        """Return ``(quotient, remainder)`` coefficient lists of schoolbook division by ``other``."""
        assert type(other)==Polynomial
        quot=[0]*(len(self.polynomial)-len(other.polynomial)+1)
        rema=[x for x in self.polynomial]
        if mod:
            inve=MOD(mod).Pow(other.polynomial[-1],-1)
        else:
            inve=1/other.polynomial[-1]
        # Eliminate from the highest degree downwards.
        for i in range(len(self.polynomial)-len(other.polynomial),-1,-1):
            quot[i]=rema[i+len(other.polynomial)-1]*inve
            if mod:
                quot[i]%=mod
            for j in range(len(other.polynomial)):
                rema[i+j]-=quot[i]*other.polynomial[j]
                if mod:
                    rema[i+j]%=mod
        return quot,rema

    def __floordiv__(self,other):
        """Polynomial quotient of Euclidean division by ``other``."""
        quot,_=self._long_division(other)
        quot=Polynomial(quot,max_degree=self.max_degree,eps=self.eps,mod=mod)
        return quot

    def __mod__(self,other):
        """Polynomial remainder of Euclidean division by ``other`` (trailing zeros stripped)."""
        _,rema=self._long_division(other)
        while rema and abs(rema[-1])<=self.eps:
            rema.pop()
        rema=Polynomial(rema,max_degree=self.max_degree,eps=self.eps,mod=mod)
        return rema

    def __divmod__(self,other):
        """Return ``(quotient, remainder)`` of Euclidean division by ``other``."""
        quot,rema=self._long_division(other)
        while rema and abs(rema[-1])<=self.eps:
            rema.pop()
        quot=Polynomial(quot,max_degree=self.max_degree,eps=self.eps,mod=mod)
        rema=Polynomial(rema,max_degree=self.max_degree,eps=self.eps,mod=mod)
        return quot,rema

    def __neg__(self):
        """Return the coefficient-wise negation."""
        if mod:
            nega=Polynomial([(-x)%mod for x in self.polynomial],max_degree=self.max_degree,eps=self.eps,mod=mod)
        else:
            nega=Polynomial([-x for x in self.polynomial],max_degree=self.max_degree,eps=self.eps,mod=mod)
        return nega

    def __pos__(self):
        """Return a copy."""
        posi=Polynomial([x for x in self.polynomial],max_degree=self.max_degree,eps=self.eps,mod=mod)
        return posi

    def __bool__(self):
        """True when at least one coefficient is stored."""
        # __bool__ must return a bool; returning the list raises TypeError.
        return bool(self.polynomial)

    def __getitem__(self,n):
        """Return coefficient ``n`` (0 past the end), or a Polynomial for a slice."""
        if type(n)==int:
            if n<=len(self.polynomial)-1:
                return self.polynomial[n]
            else:
                return 0
        else:
            return Polynomial(polynomial=self.polynomial[n],max_degree=self.max_degree,eps=self.eps,mod=mod)
    
    def __setitem__(self,n,a):
        """Set coefficient ``n``; writes above ``max_degree`` are ignored.

        The list is only extended when the new value is non-zero.
        """
        if mod:
            a%=mod
        if self.max_degree==-1 or n<=self.max_degree:
            if n<=len(self.polynomial)-1:
                self.polynomial[n]=a
            elif self.eps<abs(a):
                self.polynomial+=[0]*(n-len(self.polynomial))+[a]

    def __iter__(self):
        """Iterate over the stored coefficients, lowest degree first."""
        for x in self.polynomial:
            yield x

    def __call__(self,x):
        """Evaluate the polynomial at ``x``."""
        retu=0
        pow_x=1
        for i in range(len(self.polynomial)):
            retu+=pow_x*self.polynomial[i]
            pow_x*=x
            if mod:
                retu%=mod
                pow_x%=mod
        return retu

    def __str__(self):
        return "["+", ".join(map(str,self.polynomial))+"]"

    def __len__(self):
        """Number of stored coefficients."""
        return len(self.polynomial)

    def differential(self):
        """Return the formal derivative."""
        if mod:
            differential=[x*i%mod for i,x in enumerate(self.polynomial[1:],1)]
        else:
            differential=[x*i for i,x in enumerate(self.polynomial[1:],1)]
        return Polynomial(differential,max_degree=self.max_degree,eps=self.eps,mod=mod)

    def integral(self):
        """Return the formal antiderivative with constant term 0."""
        if mod:
            integral=[0]+[x*MOD(mod).Pow(i+1,-1)%mod for i,x in enumerate(self.polynomial)]
        else:
            integral=[0]+[x/(i+1) for i,x in enumerate(self.polynomial)]
        while integral and abs(integral[-1])<=self.eps:
            integral.pop()
        return Polynomial(integral,max_degree=self.max_degree,eps=self.eps,mod=mod)

    def log(self):
        """Return the power-series logarithm: the integral of ``f'/f``.

        Requires a finite ``max_degree`` and constant term 1.
        """
        assert self.max_degree!=-1
        assert self.polynomial and abs(self.polynomial[0]-1)<=self.eps
        log=(1/self)
        if mod:
            log=Polynomial(NTT(self.differential().polynomial,log.polynomial),max_degree=self.max_degree,eps=self.eps,mod=mod)
        else:
            log=Polynomial(FFT(self.differential().polynomial,log.polynomial),max_degree=self.max_degree,eps=self.eps,mod=mod)
        log=log.integral()
        return log

    def Newton(self,n0,f,differentiated_f=None):
        """Newton iteration on power series, starting from the constant ``n0``.

        Without ``differentiated_f``, ``f(prev, self.polynomial)`` must return
        the full correction term.  With it, the correction is
        ``(f(prev) - self) / differentiated_f(prev)``.  Each round computes
        ``prev - correction`` until ``max_degree+1`` coefficients exist.
        """
        newton=[n0]
        while len(newton)<self.max_degree+1:
            prev=newton
            if differentiated_f is None:
                newton=f(prev,self.polynomial)
            else:
                newton=f(prev)
                for i in range(min(len(self.polynomial),len(newton))):
                    newton[i]-=self.polynomial[i]
                    # Reduce only in modular mode; mod==0 would divide by zero.
                    if mod:
                        newton[i]%=mod
                if mod:
                    newton=NTT(newton,(1/Polynomial(differentiated_f(prev),max_degree=len(newton)-1,eps=self.eps,mod=mod)).polynomial)[:len(newton)]
                else:
                    newton=FFT(newton,(1/Polynomial(differentiated_f(prev),max_degree=len(newton)-1,eps=self.eps,mod=mod)).polynomial)[:len(newton)]
            for i in range(len(newton)):
                newton[i]=-newton[i]
                if mod:
                    newton[i]%=mod
            for i in range(len(prev)):
                newton[i]+=prev[i]
                if mod:
                    newton[i]%=mod
        newton=newton[:self.max_degree+1]
        # abs() so negative coefficients are not mistaken for zeros.
        while newton and abs(newton[-1])<=self.eps:
            newton.pop()
        return Polynomial(newton,max_degree=self.max_degree,eps=self.eps,mod=mod)

    def sqrt(self):
        """Return a power-series square root, or ``None`` when none is found.

        ``None`` is returned when the lowest non-zero term has odd degree or
        when its coefficient has no square root.
        """
        if self.polynomial:
            # cnt0 = index of the first non-zero coefficient.
            for cnt0 in range(len(self.polynomial)):
                if self.polynomial[cnt0]:
                    break
            if cnt0%2:
                sqrt=None
            else:
                # Default to "no root" so n0 is always bound below.
                n0=None
                if mod:
                    n0=Tonelli_Shanks(self.polynomial[cnt0],mod)
                else:
                    if self.polynomial[cnt0]>=self.eps:
                        n0=self.polynomial[cnt0]**.5
                if n0 is None:
                    sqrt=None
                else:
                    def f(prev):
                        if mod:
                            return NTT_Pow(prev,2)+[0]
                        else:
                            return FFT_Pow(prev,2)+[0]
                    def differentiated_f(prev):
                        retu=[0]*(2*len(prev)-1)
                        for i in range(len(prev)):
                            retu[i]+=2*prev[i]
                            if mod:
                                # Was a bare expression with no effect.
                                retu[i]%=mod
                        return retu
                    sqrt=[0]*(cnt0//2)+Polynomial(self.polynomial[cnt0:],max_degree=self.max_degree-cnt0//2,mod=mod).Newton(n0,f,differentiated_f).polynomial
                    sqrt=Polynomial(sqrt,max_degree=self.max_degree,eps=self.eps,mod=mod)
        else:
            sqrt=Polynomial([],max_degree=self.max_degree,eps=self.eps,mod=mod)
        return sqrt

    def exp(self):
        """Return the power-series exponential; requires constant term 0."""
        assert not self.polynomial or abs(self.polynomial[0])<=self.eps
        def f(prev,poly):
            # Correction term prev*(log(prev)-poly) at doubled precision.
            newton=Polynomial(prev,max_degree=2*len(prev)-1,eps=self.eps,mod=mod).log().polynomial
            newton+=[0]*(2*len(prev)-len(newton))
            for i in range(min(len(poly),len(newton))):
                newton[i]-=poly[i]
            if mod:
                for i in range(len(newton)):
                    newton[i]%=mod
            if mod:
                return NTT(prev,newton)[:2*len(prev)]
            else:
                return FFT(prev,newton)[:2*len(prev)]
        return Polynomial(self.polynomial,max_degree=self.max_degree,mod=mod).Newton(1,f)

    def Degree(self):
        """Return ``len(coefficients)-1`` (so -1 for the empty polynomial)."""
        return len(self.polynomial)-1

# Read the three problem parameters from one line of standard input.
N,M,K=map(int,readline().split())
mod=998244353
MD=MOD(mod)
# Comb(M,K) below indexes the factorial tables at M, so size them for
# max(N,M) rather than N alone (otherwise M>N raises IndexError).
MD.Build_Fact(max(N,M))
# poly[i] = 1/(i+1)!, the first N-K+1 coefficients of (e^x-1)/x.
poly=[MD.Fact_Inve(i+1) for i in range(N-K+1)]
# Only coefficients 0..N-K of the series are read below, and low-order
# coefficients of log/exp do not depend on higher ones, so truncating at
# degree N-K instead of N gives the same values with less work.
P=Polynomial(poly,max_degree=N-K,mod=mod)
# P <- exp(K*log(P)), i.e. P**K as a power series.
P=P.log()
for i in range(N-K+1):
    P[i]*=K
    P[i]%=mod
P=P.exp()
ans=0
for n in range(K,N+1):
    ans+=P[n-K]*MD.Pow(M,N-n)%mod*MD.Fact_Inve(N-n)%mod
ans*=MD.Comb(M,K)*MD.Fact(N)%mod
ans%=mod
print(ans)
0