MOD = 998244353


def solve(n, k, mod=MOD):
    """Return ``k * (k - 1) * n / k**n`` reduced modulo the prime *mod*.

    The division is done with a modular inverse obtained from Fermat's little
    theorem (``x**(mod-2) == x**-1 (mod mod)``). This is only valid when *mod*
    is prime and ``k`` is not a multiple of *mod*. If ``k % mod == 0`` the
    "inverse" computed here is 0, and so is the result.

    NOTE(review): the problem statement behind this formula is not visible
    here; the expression is kept exactly as in the original script.
    """
    numerator = k * (k - 1) * n % mod
    denominator = pow(k, n, mod)  # k**n mod p, without building a huge int
    inverse = pow(denominator, mod - 2, mod)  # Fermat inverse; needs prime mod
    return numerator * inverse % mod


def main():
    """Read ``N K`` from one line of stdin and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()