n, k = map(int, input().split()) mod = 998244353 ans = k * (k - 1) % mod * n % mod ans *= pow(pow(k, n, mod), mod - 2, mod) ans %= mod print(ans)