MOD = 998244353


def inverse(n, d):
    """Return ``n / d`` modulo ``MOD``, i.e. ``n * d**-1 % MOD``.

    Despite the name this is modular *division*; ``inverse(1, d)`` is the
    modular inverse of ``d``.

    Raises:
        ValueError: if ``d`` has no inverse modulo ``MOD`` (``d % MOD == 0``),
            as raised by ``pow(d, -1, MOD)``.
    """
    # pow with a negative exponent and a modulus needs Python 3.8+.
    return n * pow(d, -1, MOD) % MOD


def solve(n, k):
    """Return ``(1/k)**(n-1) * (k-1) * n`` reduced modulo ``MOD``.

    This is the same expression the script prints, as a pure function so it
    can be imported and tested without touching stdin.
    """
    return pow(inverse(1, k), n - 1, MOD) * (k - 1) * n % MOD


def main():
    """Read ``N K`` from one line of stdin and print ``solve(N, K)``."""
    N, K = map(int, input().split())
    print(solve(N, K))


# Guard the I/O so importing this module does not block on stdin.
if __name__ == "__main__":
    main()