#include <iostream> #include <string> #include <vector> #include <algorithm> #include <functional> #include <cmath> #include <iomanip> #include <stack> #include <queue> #include <numeric> #include <map> #include <unordered_map> #include <set> #include <fstream> #include <chrono> #include <random> #include <bitset> #include <atcoder/all> #define rep(i,n) for(int i=0;i<(n);i++) #define all(x) x.begin(), x.end() #define rall(x) x.rbegin(), x.rend() #define sz(x) ((int)(x).size()) #define pb push_back using ll = long long; using namespace std; template<class T>bool chmax(T &a, const T &b) { if (a<b) { a=b; return 1; } return 0; } template<class T>bool chmin(T &a, const T &b) { if (b<a) { a=b; return 1; } return 0; } ll gcd(ll a, ll b) {return b?gcd(b,a%b):a;} ll lcm(ll a, ll b) {return a/gcd(a,b)*b;} const ll mod = 998244353; ll mpow(ll a, ll x){ ll res = 1; while(x > 0){ if(x & 1) res = (res * a) % mod; a = (a * a) % mod; x >>= 1; } return res; } ll inv(ll x){ ll res = 1, k = mod - 2; while(k){ if(k&1) res = (res * x) % mod; x = (x * x) % mod; k >>= 1; } return res; } int main(){ ll N,K; cin >> N >> K; ll a = mpow(inv(K),N-1) * ((K-1) * inv(K) % mod) % mod * K % mod * N % mod; cout << a << endl; return 0; };