MOD = 998244353  # prime modulus used for the final answer


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return ``k * (k - 1) * n / k**n`` reduced modulo *mod*.

    The division by ``k**n`` is done with a modular inverse: the
    three-argument ``pow`` with a negative exponent (Python 3.8+)
    returns ``(k**-1)**n`` mod *mod*.

    Args:
        n: the ``N`` value read from input.
        k: the ``K`` value read from input.
        mod: modulus for the result; defaults to ``MOD``.

    Returns:
        ``(k * (k - 1) * n * pow(k, -n, mod)) % mod``.

    Raises:
        ValueError: when ``n > 0`` and ``k`` has no inverse modulo *mod*
            (for the prime default ``MOD``, when ``k % MOD == 0``).
    """
    return (k * (k - 1) * n * pow(k, -n, mod)) % mod


def main() -> None:
    """Read ``N K`` from one line of stdin and print ``solve(N, K)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()