#include <iostream>
#include <vector>
#include <cmath>
#include <map>
#include <set>
#include <iomanip>
#include <queue>
#include <algorithm>
#include <numeric>
#include <deque>

using namespace std;

/// Shorthand for the signed (at least 64-bit) integer type.
using ll = long long;
/// Prime modulus 998244353 (= 119 * 2^23 + 1). Because it is prime,
/// a^(modc-2) is the modular inverse of any a not divisible by modc
/// (Fermat's little theorem).
constexpr ll modc{998244353};
/// Computes b^e mod m by binary (square-and-multiply) exponentiation.
///
/// @param b  base; any value, including negative (it is reduced into [0, m)).
/// @param e  exponent; values <= 0 are treated as 0, so the result is 1 % m.
/// @param m  modulus; must be positive, and (m-1)^2 must fit in a long long
///           (true for m below ~3.03e9), otherwise the products overflow.
/// @return   b^e mod m, always in [0, m).
long long mod_exp(long long b, long long e, long long m){
    // `%` truncates toward zero in C++, so a negative base would leave a
    // negative remainder and poison every later product; shift it into [0, m).
    b %= m;
    if (b < 0) b += m;

    // Start from 1 % m rather than 1 so that m == 1 correctly yields 0.
    long long ans = 1 % m;

    // A zero base needs no special case: the first set bit of e multiplies
    // ans by 0, and squaring keeps b at 0.
    while (e > 0){
        if (e & 1LL) ans = ans * b % m;
        e >>= 1;
        b = b * b % m;
    }

    return ans;
}

int main(){

    ll K, N, a;
    cin >> N >> K;
    a = mod_exp(K, N-1, modc);

    cout << (K-1) * N % modc * mod_exp(a, modc-2, modc)% modc <<  endl;

    return 0;
}