//#include #include using namespace std; #include using namespace atcoder; using mint = modint998244353; typedef long long ll; #define all(x) (x).begin(), (x).end() #define rall(x) (x).rbegin(), (x).rend() const int MAX = 1e9; const int MIN = -1*1e9; const ll MAXLL = 1e18; const ll MINLL = -1*1e18; int main() { int N,K; cin >> N >> K; cout << (((mint(1)/K).pow(N))*(K*(K-1))*(N)).val() << endl; return 0; }