// Reads N and K from stdin and prints K*(K-1)*N / K^N modulo 998244353,
// where the division is a multiplication by the modular inverse of K^N.
#include <iostream>

using ll = long long;

// Prime modulus used for every reduction below.
constexpr ll mod99 = 998244353;

// Returns x^n mod m by binary exponentiation.
//   x: base; reduced into [0, m) first, so any value (negative, or far larger
//      than m) is accepted without overflowing the first squaring.
//   n: exponent; n <= 0 yields 1 % m.
//   m: modulus; must be positive and small enough that (m-1)^2 fits in a
//      long long (m < ~3.03e9).
ll powmod(ll x, ll n, ll m) {
    x %= m;
    if (x < 0) x += m;
    ll res = 1 % m;  // 1 % m so that m == 1 yields 0, not 1
    while (n > 0) {
        if (n & 1) res = res * x % m;
        x = x * x % m;
        n >>= 1;
    }
    return res;
}

int main() {
    ll N{};
    ll K{};
    std::cin >> N >> K;

    // Reduce the inputs before multiplying: the raw product K*(K-1)*N can
    // overflow long long for large inputs, while residues < mod99 cannot
    // (each pairwise product stays below mod99^2 ~ 1e18).
    const ll k_mod = (K % mod99 + mod99) % mod99;
    const ll n_mod = (N % mod99 + mod99) % mod99;
    const ll numerator = k_mod * ((k_mod + mod99 - 1) % mod99) % mod99 * n_mod % mod99;

    // Fermat's little theorem: a^(p-2) is the inverse of a modulo prime p.
    // NOTE(review): if K is a multiple of mod99, K^N has no inverse and this
    // yields 0 — presumably the problem's constraints exclude that; confirm.
    const ll denominator = powmod(K, N, mod99);
    const ll answer = numerator * powmod(denominator, mod99 - 2, mod99) % mod99;

    std::cout << answer << '\n';
}