N,K=map(int,input().split()) MOD=998244353 ans=N*K*(K-1)%MOD ans*=pow(K,N*(MOD-2),MOD) ans%=MOD print(ans)