#include <iostream>

using ll = long long;

// Prime modulus; every result below is reported modulo this value.
constexpr ll m = 998244353;

// Returns a^n mod m using binary (square-and-multiply) exponentiation.
// `a` may be any value: it is reduced into [0, m) first so the products
// below stay small. `n` is expected to be >= 0; a negative exponent
// returns 1 instead of looping forever (the original `while (n)` never
// terminated because `n >>= 1` keeps a negative value at -1).
ll mpow(ll a, ll n) {
    a %= m;
    if (a < 0) a += m;
    ll ret = 1;
    while (n > 0) {
        if (n & 1) ret = ret * a % m;
        a = a * a % m;
        n >>= 1;
    }
    return ret;
}

// Reads N and K and prints (K^{-1})^N * N * K * (K - 1) modulo m,
// i.e. N * K * (K - 1) / K^N in the field Z/mZ.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    ll N = 0;
    ll K = 0;
    if (!(std::cin >> N >> K)) return 0;  // malformed / missing input

    // Map any integer into [0, m) so factors such as K - 1 (when K == 0)
    // never make the printed answer negative.
    const auto norm = [](ll x) {
        x %= m;
        return x < 0 ? x + m : x;
    };

    // Modular inverse of K via Fermat's little theorem (m is prime),
    // then raised to the N-th power.
    ll ans = mpow(mpow(K, m - 2), N);
    ans = ans * norm(N) % m;
    ans = ans * norm(K) % m;
    ans = ans * norm(K - 1) % m;

    std::cout << ans << '\n';
    return 0;
}