MOD=998244353 def pow_mod(n,k,m): res=1 while k: if k&1: res=res*n%m n=n*n%m k>>=1 return res n,k=map(int,input().split()) print(n*k*(k-1)*pow_mod(pow_mod(k,n,MOD),MOD-2,MOD)%MOD)