MOD = 998244353


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return n * (k - 1) * k**-(n - 1) reduced modulo *mod*.

    The modular inverse of *k* is taken via Fermat's little theorem
    (``k ** (mod - 2)``), so *mod* must be prime.  If *k* is a multiple
    of *mod* the "inverse" is 0 and the result degenerates accordingly.

    Args:
        n: first input value (the multiplier and the exponent source).
        k: second input value (its inverse is raised to ``n - 1``).
        mod: prime modulus; defaults to 998244353.

    Returns:
        The value in ``range(mod)``.
    """
    # Fermat: k^(p-2) is k^{-1} modulo a prime p.
    inv_k = pow(k, mod - 2, mod)
    numerator = n * (k - 1) % mod
    return numerator * pow(inv_k, n - 1, mod) % mod


def main() -> None:
    """Read "N K" from stdin and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()