"""Print K * (K - 1) * N / K**N as a residue modulo 998244353.

Reads two integers ``N K`` from one line of standard input and writes the
result to standard output.
"""

# Prime modulus; primality is what makes the Fermat inverse in solve() valid.
MOD = 998244353


def solve(n, k):
    """Return ``k * (k - 1) * n / k**n`` modulo ``MOD``.

    The division is carried out as a multiplication by the modular inverse
    of ``k**n``, obtained through Fermat's little theorem
    (``a**(MOD - 2) == a**-1 (mod MOD)`` for ``a`` not divisible by ``MOD``).

    NOTE(review): the numerator/denominator naming suggests a ratio of
    "matching patterns" to "all patterns" -- confirm against the problem
    statement, which is not visible from this file.

    Args:
        n: Exponent of the denominator and linear factor of the numerator.
        k: Base of the denominator; also contributes ``k * (k - 1)`` to the
            numerator.

    Returns:
        An integer in ``[0, MOD)``. When ``k`` is a multiple of ``MOD`` the
        denominator has no inverse and the expression evaluates to 0.
    """
    numerator = k * (k - 1) * n % MOD
    denominator = pow(k, n, MOD)
    # Fermat inverse instead of pow(x, -1, MOD): it never raises, and it
    # yields 0 for a non-invertible denominator.
    inverse = pow(denominator, MOD - 2, MOD)
    return numerator * inverse % MOD


def main():
    """Read ``N K`` from stdin and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()