MOD = 998244353


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return ``n * k * (k - 1) / k**n`` as a residue modulo ``mod``.

    The division is done by multiplying by the modular inverse of
    ``k**n``, computed with Fermat's little theorem. That is only valid
    when ``mod`` is prime and ``k**n`` is not a multiple of ``mod``.
    Otherwise ``pow(0, mod - 2, mod)`` is 0 and the result is 0.

    Args:
        n: The exponent and leading factor of the numerator.
        k: The base of the denominator and factor of the numerator.
        mod: Modulus; defaults to the prime 998244353.

    Returns:
        The value ``n*k*(k-1) * inverse(k**n)`` reduced into ``[0, mod)``.
    """
    # Reduce each factor as we go so intermediates stay below mod**2.
    numerator = n % mod * (k % mod) % mod * ((k - 1) % mod) % mod
    denominator = pow(k, n, mod)
    # Fermat: for prime p and a not divisible by p, a^(p-2) == a^(-1) (mod p).
    return numerator * pow(denominator, mod - 2, mod) % mod


def main() -> None:
    """Read ``N K`` from stdin and print the answer modulo ``MOD``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()