MOD = 998244353 n, k = map(int, input().split()) numerator = (n * (k - 1)) % MOD denominator = pow(k, n - 1, MOD) inv_denominator = pow(denominator, MOD - 2, MOD) print((numerator * inv_denominator) % MOD)