MOD = 998244353 N, K = map(int, input().split()) ans = N * K * (K - 1) ans %= MOD for _ in range(N): ans *= pow(K, MOD-2, MOD) ans %= MOD print(ans)