MOD = 998244353  # prime modulus, so Fermat's little theorem gives inverses


def solve(n: int, k: int) -> int:
    """Return (k * (k - 1) * n) / k**n reduced modulo MOD.

    The quotient is taken in the field Z/MOD: the numerator is multiplied by
    the modular inverse of k**n, computed as (k**n)**(MOD - 2) mod MOD.
    NOTE(review): the combinatorial meaning of this ratio (it looks like a
    "favourable count / total count" probability) comes from the original
    task statement, which is not visible here -- confirm against it.

    Args:
        n: first integer read from input.
        k: second integer read from input.

    Returns:
        The ratio as an integer in [0, MOD).
        If k is a multiple of MOD, k**n has no inverse. pow(0, MOD - 2, MOD)
        is 0, so 0 is returned, matching the original script's output.
    """
    # Python ints are unbounded, so the product cannot overflow before % MOD.
    pattern = k * (k - 1) * n % MOD
    all_pattern = pow(k, n, MOD)
    # Division modulo a prime: multiply by the Fermat inverse of the denominator.
    return pattern * pow(all_pattern, MOD - 2, MOD) % MOD


def main() -> None:
    """Read "N K" from stdin and print solve(N, K)."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()