"""Read ``n`` and ``k`` from stdin and print n*(k-1)*k^(mod-n) modulo 998244353."""

MOD = 998244353


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return ``n * (k - 1) * k**(mod - n)`` reduced modulo *mod*.

    When *mod* is prime and *k* is not a multiple of it, Fermat's little
    theorem gives ``k**(mod - 1) == 1 (mod mod)``. So ``k**(mod - n)`` is the
    modular form of ``k**(1 - n)``, and the result equals
    ``n * (k - 1) / k**(n - 1)`` in the integers modulo *mod*.

    Args:
        n: First input integer.
        k: Second input integer (the base of the power).
        mod: Modulus; the Fermat reading above assumes it is prime.

    Returns:
        The value in ``range(mod)``.

    Raises:
        ValueError: if ``n > mod`` (negative exponent) and *k* is not
            invertible modulo *mod*. A negative exponent in three-argument
            ``pow`` requires Python 3.8+.
    """
    # Same left-to-right evaluation as the original: ((n*(k-1)) * pow) % mod.
    return n * (k - 1) * pow(k, mod - n, mod) % mod


def main() -> None:
    """Read "n k" from one line of stdin and print ``solve(n, k)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()