#include <bits/stdc++.h>
using namespace std;
// Shorthand type aliases and helper macros used throughout the file.
using ll = long long;
using P = pair<ll,ll>;
// Stream manipulators for fixed-point output with x decimals, e.g. `cout << fix(10) << v;`.
#define fix(x) fixed << setprecision(x)
// Template arguments for a min-heap, e.g. `priority_queue<asc(ll)> pq;`.
#define asc(x) x, vector<x>, greater<x>
// Loop variable i (type ll) over [0, n). `n` is parenthesised so an argument
// containing low-precedence operators (?:, &, |, ==, ...) is compared as a
// whole; note it is still re-evaluated on every iteration.
#define rep(i, n) for(ll i = 0; i < (n); i++)
// Expands to the begin/end iterator pair of a container, e.g. `sort(all(v));`.
#define all(x) (x).begin(),(x).end()
// chmin: overwrite `a` with `b` when `a > b` (b is strictly smaller).
// chmax: overwrite `a` with `b` when `a < b` (b is strictly larger).
// Both report whether `a` was changed.
template <class T>
bool chmin(T& a, const T& b) {
    const bool improves = a > b;
    if (improves) a = b;
    return improves;
}
template <class T>
bool chmax(T& a, const T& b) {
    const bool improves = a < b;
    if (improves) a = b;
    return improves;
}
// Modulus for all modular arithmetic in this file. It is prime, so
// mpow(a, MOD - 2) is the modular inverse of a (for a not divisible by MOD).
constexpr ll MOD = 998244353;
// Large sentinel values (presumably "infinity" for min/max searches; not used
// by the visible code — confirm before relying on them).
constexpr ll INFLL = 1LL << 62;
constexpr int INF = 1 << 30;
// Modular exponentiation by repeated squaring: returns x^y mod MOD in O(log y)
// multiplications. The result is always in [0, MOD), including for negative x.
// Precondition: y >= 0. For y <= 0 the loop does not run and 1 is returned.
// Requires MOD < ~3.03e9 so the product of two residues fits in long long.
ll mpow(ll x, ll y){
    ll res = 1;
    x %= MOD;
    if(x < 0) x += MOD;  // C++ % keeps the dividend's sign; normalise into [0, MOD)
    while(y > 0){
        if(y & 1) res = res * x % MOD;  // fold in the current bit of the exponent
        x = x * x % MOD;                // x becomes x^(2^k) for the next bit
        y >>= 1;
    }
    return res;
}
int main(){
    cin.tie(nullptr);
    ios::sync_with_stdio(false);
    ll n,k;
    cin >> n >> k;
    cout << n*k*(k-1)%MOD*mpow(mpow(k,n),MOD-2)%MOD << '\n';
    return 0;
}