"""Read ``n k`` from stdin and print n * (k - 1) / k**(n - 1) modulo 998244353."""

MOD = 998244353  # prime modulus; primality is what makes the Fermat inverse valid


def solve(n: int, k: int) -> int:
    """Return the fraction n * (k - 1) / k**(n - 1) as a residue modulo MOD.

    The division is carried out by multiplying with the modular inverse of
    the denominator, obtained through Fermat's little theorem
    (``a**(MOD - 2)`` is the inverse of ``a`` modulo the prime ``MOD``).

    Args:
        n: First input integer; multiplies the numerator and sets the
            exponent ``n - 1`` of the denominator.
        k: Second input integer; base of the denominator power.

    Returns:
        An integer in ``[0, MOD)``. When ``k`` is a multiple of ``MOD`` and
        ``n > 1`` the denominator is 0 modulo ``MOD`` and has no inverse;
        the Fermat power then evaluates to 0, so the result is 0.
    """
    numerator = n * (k - 1) % MOD
    denominator = pow(k, n - 1, MOD)
    # Modular inverse via Fermat's little theorem (MOD is prime).
    inv_denominator = pow(denominator, MOD - 2, MOD)
    return numerator * inv_denominator % MOD


def main() -> None:
    """Parse two whitespace-separated integers from one input line and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()