mod=998244353 N,K=map(int,input().split()) ans=N*K*(K-1)%mod K_inv=pow(K,mod-2,mod) print(ans*pow(K_inv,N,mod)%mod)