#include <bits/stdc++.h> using namespace std; using ll = long long; using P = pair<ll,ll>; #define fix(x) fixed << setprecision(x) #define asc(x) x, vector<x>, greater<x> #define rep(i, n) for(ll i = 0; i < n; i++) #define all(x) (x).begin(),(x).end() template<class T>bool chmin(T&a, const T&b){if(a>b){a=b;return 1;}return 0;} template<class T>bool chmax(T&a, const T&b){if(a<b){a=b;return 1;}return 0;} constexpr ll INFLL = (1LL << 62), MOD = 998244353; constexpr int INF = (1 << 30); ll mpow(ll x, ll y){ ll res = 1; x %= MOD; while(y){ if(y%2) res = res * x % MOD; x = x * x % MOD; y /= 2; } return res; } int main(){ cin.tie(nullptr); ios::sync_with_stdio(false); ll n,k; cin >> n >> k; cout << n*k*(k-1)%MOD*mpow(mpow(k,n),MOD-2)%MOD << '\n'; return 0; }