MOD = 998244353


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return ``n * (k - 1) / k**(n - 1)`` reduced modulo ``mod``.

    The division is done with a modular inverse: ``pow`` with a negative
    exponent and a modulus (Python 3.8+) returns ``(k**-1) ** (n - 1) % mod``.

    Args:
        n: Integer read from input; ``n - 1`` is the exponent of ``k``.
        k: Integer read from input.
        mod: Modulus for the result (defaults to 998244353).

    Returns:
        The value ``n * (k - 1) * inverse(k)**(n - 1) % mod``.

    Raises:
        ValueError: If ``n > 1`` and ``k`` has no inverse modulo ``mod``
            (for a prime ``mod``, when ``k % mod == 0``). This comes from
            ``pow`` itself.
    """
    return pow(k, -(n - 1), mod) * (k - 1) * n % mod


def main() -> None:
    """Read ``n k`` from one line of stdin and print ``solve(n, k)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()