#include using namespace std; using ll=long long; long long modpow(long long a, long long n, long long mod) { long long res = 1; while (n > 0) { if (n & 1) res = res * a % mod; a = a * a % mod; n >>= 1; } return res; } long long modinv(long long a, long long m) { long long b = m, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b; swap(a, b); u -= t * v; swap(u, v); } u %= m; if (u < 0) u += m; return u; } int main(){ ll n,k; cin>>n>>k; ll mod=998244353; ll ans=n*k; ans%=mod; ans*=(k-1); ans%=mod; ans*=modinv(modpow(k,n,mod),mod); ans%=mod; cout<