#include <cassert>
#include <iostream>
#include <vector>
using namespace std;
#ifdef LOCAL
#include "debug.hpp"
#else
#define debug(...) 1
#endif

// Integer modulo a compile-time modulus `m`, kept normalized in [0, m).
//
// Addition/subtraction rely on 2 * (m - 1) fitting in `unsigned int`, which
// holds for every positive `int` m. inv() and division use Fermat's little
// theorem (x^(m-2)), so they are only correct when `m` is prime.
template <int m>
struct modint {
  static_assert(m >= 1, "modulus must be positive");

  unsigned int v = 0;

  static constexpr int mod() { return m; }
  static constexpr unsigned int umod() { return m; }

  // Normalized representative in [0, m).
  unsigned int val() const { return v; }

  modint() : v(0) {}
  // Reduces any signed value into [0, m); negative inputs wrap around.
  // Implicit on purpose so expressions like `Mint(n) * k` work with plain ints.
  modint(long long _v) {
    long long x = (long long) (_v % umod());
    if (x < 0) {
      x += umod();
    }
    v = (unsigned int) x;
  }
  modint(const modint &rhs) = default;
  modint &operator=(const modint &rhs) = default;

  modint operator+() const { return *this; }
  modint operator-() const { return modint() - *this; }

  modint &operator+=(const modint &rhs) {
    v += rhs.v;
    if (v >= umod()) {
      v -= umod();
    }
    return *this;
  }
  modint operator+(const modint &rhs) const { return modint(*this) += rhs; }

  modint &operator-=(const modint &rhs) {
    v -= rhs.v;  // may wrap around as unsigned; corrected below
    if (v >= umod()) {
      v += umod();
    }
    return *this;
  }
  modint operator-(const modint &rhs) const { return modint(*this) -= rhs; }

  modint &operator*=(const modint &rhs) {
    unsigned long long x = v;  // widen so the product cannot overflow
    x *= rhs.v;
    v = (unsigned int) (x % umod());
    return *this;
  }
  modint operator*(const modint &rhs) const { return modint(*this) *= rhs; }

  // Binary exponentiation: *this raised to n, for n >= 0.
  template <class T>
  modint pow(T n) const {
    // A negative exponent would never terminate (n >>= 1 sticks at -1).
    assert(n >= 0);
    modint x = *this, ret = 1;
    while (n) {
      if (n & 1) ret *= x;
      x *= x;
      n >>= 1;
    }
    return ret;
  }

  // Multiplicative inverse via Fermat; valid only for prime m and *this != 0.
  modint inv() const { return pow(umod() - 2); }

  modint &operator/=(const modint &rhs) {
    *this *= rhs.inv();
    return *this;
  }
  modint operator/(const modint &rhs) const { return modint(*this) /= rhs; }

  // Reads a signed integer and normalizes it into [0, m); leaves `a` untouched
  // if extraction fails.
  friend istream &operator>>(istream &is, modint &a) {
    long long x;
    if (is >> x) {
      a = modint(x);
    }
    return is;
  }
  friend ostream &operator<<(ostream &os, const modint &a) { return os << a.v; }
};

constexpr int md = 998244353;
// constexpr int md = 1000000007;

// Tables filled by cominit: factorials, inverses of 1..MAX, inverse factorials.
vector<modint<md>> fact, inv, inv_fact;

// Precomputes fact[i], inv[i], inv_fact[i] for 0 <= i <= MAX (modulo md).
// Requires 0 <= MAX < md so that every i in [1, MAX] is invertible.
// The tables always hold at least indices 0 and 1, so MAX == 0 is safe.
void cominit(int MAX) {
  assert(MAX >= 0 && MAX < md);
  const int size = (MAX < 1 ? 1 : MAX) + 1;
  fact.resize(size);
  inv.resize(size);
  inv_fact.resize(size);
  fact[0] = fact[1] = 1;
  inv_fact[0] = inv_fact[1] = 1;
  inv[1] = 1;
  for (int i = 2; i < size; i++) {
    fact[i] = fact[i - 1] * i;
    // Linear-time inverse: inv[i] = -(md / i) * inv[md % i]  (mod md).
    inv[i] = -inv[md % i] * modint<md>(md / i);
    inv_fact[i] = inv_fact[i - 1] * inv[i];
  }
}
template modint Com(T n, T k) { assert(n < (int) fact.size() && k < (int) fact.size()); if (n < k) return 0; if (n < 0 || k < 0) return 0; return fact[n] * inv_fact[k] * inv_fact[n - k]; } template modint Per(T n, T k) { assert(n < (int) fact.size() && k < (int) fact.size()); if (n < k) return 0; if (n < 0 || k < 0) return 0; return fact[n] * inv_fact[n - k]; } using Mint = modint; int main() { ios::sync_with_stdio(false); cin.tie(nullptr); int n, k; cin >> n >> k; cominit(n); Mint ans = Mint(n) * k * (k - 1) / Mint(k).pow(n); cout << ans << '\n'; }