#yukicoder393B N,K=map(int,input().split()); MOD=998244353 print((K-1)*N%MOD*pow(pow(K,MOD-2,MOD),N-1,MOD)%MOD)