MOD = 998244353  # prime modulus; makes pow(x, MOD - 2, MOD) a modular inverse


def solve(n: int, k: int) -> int:
    """Return n * (k - 1) / k**(n - 1) reduced modulo MOD.

    The product is evaluated factor by factor, left to right, each step
    reduced mod MOD:

        inv_k**(n - 1) * inv_k * k * (k - 1) * n

    where inv_k = k**(MOD - 2) mod MOD (Fermat's little theorem). When k is
    invertible modulo MOD, the ``inv_k * k`` pair cancels to 1. It is kept
    deliberately so that the degenerate case k % MOD == 0 (where inv_k is 0)
    yields the same value as the original script.
    """
    inv_k = pow(k, MOD - 2, MOD)
    ans = pow(inv_k, n - 1, MOD) * inv_k % MOD
    ans = ans * k % MOD
    ans = ans * (k - 1) % MOD
    return ans * n % MOD


def main() -> None:
    """Read "N K" from one line of stdin and print solve(N, K)."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()