MOD = 998244353  # prime modulus (119 * 2**23 + 1), so Fermat inverses apply


def solve(n: int, k: int) -> int:
    """Return ``n * (k - 1) / k**(n - 1)`` reduced modulo ``MOD``.

    The division is carried out as multiplication by the modular inverse of
    ``k**(n - 1)``, computed via Fermat's little theorem.

    Args:
        n: First integer read from input.
        k: Second integer read from input.

    Returns:
        The fraction above as a residue in ``[0, MOD)``.

    Note:
        If ``k`` is a multiple of ``MOD``, the denominator has no inverse.
        ``pow(0, MOD - 2, MOD)`` is 0, so the result is silently 0 (this
        matches the original script). For ``n == 0``, ``pow`` with a
        negative exponent needs Python 3.8+. It raises ``ValueError`` when
        ``k`` is not invertible.
    """
    # Python's % always yields a non-negative residue, even for k == 0.
    numerator = n * (k - 1) % MOD
    denominator = pow(k, n - 1, MOD)
    # Fermat: a**(MOD-2) ≡ a**(-1) (mod MOD) for prime MOD and a ≢ 0.
    inv_denominator = pow(denominator, MOD - 2, MOD)
    return numerator * inv_denominator % MOD


def main() -> None:
    """Read ``n k`` from one line of stdin and print ``solve(n, k)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()