N,K = map(int,input().split()) MOD = 998244353 a = K*(K-1) * N b = pow(pow(K,N,MOD),MOD-2,MOD) print(a*b % MOD)