#include<bits/stdc++.h>
using namespace std;
using ll=long long;

long long modpow(long long a, long long n, long long mod) {
    long long res = 1;
    while (n > 0) {
        if (n & 1) res = res * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return res;
}

long long modinv(long long a, long long m) {
    long long b = m, u = 1, v = 0;
    while (b) {
        long long t = a / b;
        a -= t * b; swap(a, b);
        u -= t * v; swap(u, v);
    }
    u %= m;
    if (u < 0) u += m;
    return u;
}

int main(){
  ll n,k;
  cin>>n>>k;
  ll mod=998244353;
  ll ans=n*k;
  ans%=mod;
  ans*=(k-1);
  ans%=mod;
  ans*=modinv(modpow(k,n,mod),mod);
  ans%=mod;
  cout<<ans<<endl;
}