#include <iostream>
#include <vector>
#include <cmath>
#include <map>
#include <set>
#include <iomanip>
#include <queue>
#include <algorithm>
#include <numeric>
#include <deque>
using namespace std;
using ll = long long;

// Prime modulus used for all modular arithmetic in this program.
const ll modc = 998244353;

// Computes b^e mod m by binary (square-and-multiply) exponentiation.
// - Returns 0 when b == 0 and e > 0.
// - Returns 1 % m when e <= 0 (the loop body never runs).
// The result is always in [0, m). Intermediate products are < m*m, so m must
// be small enough (~31 bits) that m*m fits in long long; modc satisfies this.
ll mod_exp(ll b, ll e, ll m){
    if (e > 0 && b == 0) return 0;
    ll ans = 1 % m;   // 1 % m so that m == 1 yields 0, not 1
    b %= m;
    if (b < 0) b += m;  // C++ '%' keeps the dividend's sign; normalize to [0, m)
    while (e > 0){
        if (e & 1LL) ans = (ans * b) % m;
        e >>= 1;
        b = (b * b) % m;
    }
    return ans;
}

// Reads N and K, then prints (K - 1) * N * (K^(N-1))^(-1) modulo modc.
// The inverse comes from Fermat's little theorem: x^(modc-2) mod modc.
// NOTE(review): if K is a multiple of modc and N > 1, K^(N-1) has no inverse;
// mod_exp then returns 0 and the program prints 0 (same as before).
int main(){
    ll N, K;
    cin >> N >> K;

    const ll denom = mod_exp(K, N - 1, modc);
    const ll inv_denom = mod_exp(denom, modc - 2, modc);

    // Reduce each factor into [0, modc) before multiplying. The raw product
    // (K - 1) * N can overflow long long (UB) for large inputs, and K - 1 may
    // be negative, which would otherwise produce a negative residue.
    const ll k_minus_1 = ((K - 1) % modc + modc) % modc;
    const ll n_mod = (N % modc + modc) % modc;

    cout << k_minus_1 * n_mod % modc * inv_denom % modc << '\n';
    return 0;
}