# Prime modulus under which the answer is reported.
MOD = 998244353


def solve(n: int, k: int) -> int:
    """Return ``n * (k - 1) / k**(n - 1)`` as a residue modulo ``MOD``.

    The division is carried out with the modular inverse of ``k``, obtained
    through Fermat's little theorem (``k**(MOD - 2) % MOD``).  That inverse
    only exists when ``k`` is not a multiple of ``MOD``.

    Args:
        n: The first input integer.
        k: The second input integer.

    Returns:
        The value ``n * (k - 1) * inverse(k)**(n - 1)`` reduced into the
        range ``[0, MOD)``.
    """
    inv_k = pow(k, MOD - 2, MOD)
    # Reduce the numerator first so the product stays small; the final
    # residue is the same as reducing once at the end.
    numerator = n * (k - 1) % MOD
    return numerator * pow(inv_k, n - 1, MOD) % MOD


def main() -> None:
    """Read ``N K`` from one line of standard input and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


# Keep I/O behind the entry-point guard so importing this module (for
# example from a test) does not block on stdin.
if __name__ == "__main__":
    main()