#include using namespace std; using std::cerr; using std::cin; using std::cout; #if __has_include() #include using mint = atcoder::modint998244353; istream &operator>>(istream &is, mint &a) { long long t; is >> t; a = t; return is; } ostream &operator<<(ostream &os, mint a) { return os << a.val(); } #endif typedef long double ld; #define long long long #define uint unsigned int #define ull unsigned long #define overload3(a, b, c, name, ...) name #define rep3(i, a, b) for (int i = (a); i < (b); i++) #define rep2(i, n) rep3(i, 0, n) #define rep1(n) rep2(__i, n) #define rep(...) overload3(__VA_ARGS__, rep3, rep2, rep1)(__VA_ARGS__) #define per3(i, a, b) for (int i = (b) - 1; i >= (a); i--) #define per2(i, n) per3(i, 0, n) #define per1(n) per2(__i, n) #define per(...) overload3(__VA_ARGS__, per3, per2, per1)(__VA_ARGS__) #define all(a) a.begin(), a.end() #define UNIQUE(a) \ sort(all(a)); \ a.erase(unique(all(a)), a.end()) #define sz(a) (int)a.size() #define vec vector #ifndef DEBUG #define cerr \ if (0) \ cerr #undef assert #define assert(...) void(0) #undef endl #define endl '\n' #endif template ostream &operator<<(ostream &os, vector a) { const int n = a.size(); rep(i, n) { os << a[i]; if (i + 1 != n) os << " "; } return os; } template ostream &operator<<(ostream &os, array a) { rep(i, n) { os << a[i]; if (i + 1 != n) os << " "; } return os; } template istream &operator>>(istream &is, vector &a) { for (T &i : a) is >> i; return is; } template bool chmin(T &x, S y) { if ((T)y < x) { x = (T)y; return true; } return false; } template bool chmax(T &x, S y) { if (x < (T)y) { x = (T)y; return true; } return false; } template void operator++(vector &a) { for (T &i : a) ++i; } template void operator--(vector &a) { for (T &i : a) --i; } template void operator++(vector &a, int) { for (T &i : a) i++; } template void operator--(vector &a, int) { for (T &i : a) i--; } using namespace atcoder; using ll = long; template vector NTT(vector a, vector b) { ll nmod = T::mod(); int n = a.size(); int m = b.size(); vector x1(n); vector y1(m); for (int i = 0; i < n; i++) { ll tmp1, tmp2, tmp3; tmp1 = a[i].val(); x1[i] = tmp1; } for (int i = 0; i < m; i++) { ll tmp1, tmp2, tmp3; tmp1 = b[i].val(); y1[i] = tmp1; } auto z1 = convolution<167772161>(x1, y1); auto z2 = convolution<469762049>(x1, y1); auto z3 = convolution<1224736769>(x1, y1); vector res(n + m - 1); constexpr ll m1 = 167772161; constexpr ll m2 = 469762049; constexpr ll m3 = 1224736769; constexpr ll m1m2 = 104391568; constexpr ll m1m2m3 = 721017874; ll mm12 = m1 * m2 % nmod; for (int i = 0; i < n + m - 1; i++) { int v1 = (z2[i] - z1[i]) * m1m2 % m2; if (v1 < 0) v1 += m2; int v2 = (z3[i] - (z1[i] + v1 * m1) % m3) * m1m2m3 % m3; if (v2 < 0) v2 += m3; res[i] = (z1[i] + v1 * m1 + v2 * mm12); } return res; } template struct FormalPowerSeries : vector { using vector::vector; using F = FormalPowerSeries; F &operator=(const vector &g) { int n = g.size(); int m = (*this).size(); if (m < n) (*this).resize(n); for (int i = 0; i < n; i++) (*this)[i] = g[i]; return (*this); } F &operator=(const F &g) { int n = g.size(); int m = (*this).size(); if (m < n) (*this).resize(n); for (int i = 0; i < n; i++) (*this)[i] = g[i]; return (*this); } F &operator-() { for (int i = 0; i < (*this).size(); i++) (*this)[i] *= -1; return (*this); } F &operator+=(const F &g) { int n = (*this).size(); int m = g.size(); if (n < m) (*this).resize(m); for (int i = 0; i < m; i++) (*this)[i] += g[i]; return (*this); } F &operator+=(const T &r) { if ((*this).size() == 0) (*this).resize(1); (*this)[0] += r; return (*this); } F &operator-=(const F &g) { int n = (*this).size(); int m = g.size(); if (n < m) (*this).resize(m); for (int i = 0; i < m; i++) (*this)[i] -= g[i]; return (*this); } F &operator-=(const T &r) { if ((*this).size() == 0) (*this).resize(1); (*this)[0] -= r; return (*this); } F &operator*=(const F &g) { (*this) = convolution((*this), g); return (*this); } F &operator*=(const T &r) { for (int i = 0; i < (*this).size(); i++) (*this)[i] *= r; return (*this); } F &operator/=(const F &g) { int n = (*this).size(); (*this) = convolution((*this), g.inv()); (*this).resize(n); return (*this); } F &operator/=(const T &r) { r = r.inv(); for (int i = 0; i < (*this).size(); i++) (*this)[i] *= r; return (*this); } F &operator<<=(const int d) { int n = (*this).size(); (*this).insert((*this).begin(), d, 0); (*this).resize(n); return *this; } F &operator>>=(const int d) { int n = (*this).size(); (*this).erase((*this).begin(), (*this).begin() + min(n, d)); (*this).resize(n); return *this; } F operator*(const T &g) const { return F(*this) *= g; } F operator-(const T &g) const { return F(*this) -= g; } F operator+(const T &g) const { return F(*this) += g; } F operator/(const T &g) const { return F(*this) /= g; } F operator*(const F &g) const { return F(*this) *= g; } F operator-(const F &g) const { return F(*this) -= g; } F operator+(const F &g) const { return F(*this) += g; } F operator/(const F &g) const { return F(*this) /= g; } F operator%(const F &g) const { return F(*this) %= g; } F operator<<(const int d) const { return F(*this) <<= d; } F operator>>(const int d) const { return F(*this) >>= d; } F pre(int sz) const { return F(begin(*this), begin(*this) + min((int)this->size(), sz)); } F inv(int deg = -1) const { int n = (*this).size(); if (deg == -1) deg = n; assert(n > 0 && (*this)[0] != T(0)); F g(1); g[0] = (*this)[0].inv(); while (g.size() < deg) { int m = g.size(); F f(begin(*this), begin(*this) + min(n, 2 * m)); F r(g); f.resize(2 * m); r.resize(2 * m); internal::butterfly(f); internal::butterfly(r); for (int i = 0; i < 2 * m; i++) f[i] *= r[i]; internal::butterfly_inv(f); f.erase(f.begin(), f.begin() + m); f.resize(2 * m); internal::butterfly(f); for (int i = 0; i < 2 * m; i++) f[i] *= r[i]; internal::butterfly_inv(f); T in = T(2 * m).inv(); in *= -in; for (int i = 0; i < m; i++) f[i] *= in; g.insert(g.end(), f.begin(), f.begin() + m); } return g.pre(deg); } T eval(const T &a) { T x = 1; T ret = 0; for (int i = 0; i < (*this).size(); i++) { ret += (*this)[i] * x; x *= a; } return ret; } void onemul(const int d, const T c) { int n = (*this).size(); for (int i = n - d - 1; i >= 0; i--) { (*this)[i + d] += (*this)[i] * c; } } void onediv(const int d, const T c) { int n = (*this).size(); for (int i = 0; i < n - d; i++) { (*this)[i + d] -= (*this)[i] * c; } } F diff() const { int n = (*this).size(); F ret(n); for (int i = 1; i < n; i++) ret[i - 1] = (*this)[i] * i; ret[n - 1] = 0; return ret; } F integral() const { int n = (*this).size(), mod = T::mod(); vector inv(n); inv[1] = 1; for (int i = 2; i < n; i++) inv[i] = T(mod) - inv[mod % i] * (mod / i); F ret(n); for (int i = n - 2; i >= 0; i--) ret[i + 1] = (*this)[i] * inv[i + 1]; ret[0] = 0; return ret; } F log(int deg = -1) const { int n = (*this).size(); if (deg == -1) deg = n; assert((*this)[0] == T(1)); return ((*this).diff() * (*this).inv(deg)).pre(deg).integral(); } F exp(int deg = -1) const { int n = (*this).size(); if (deg == -1) deg = n; assert(n == 0 || (*this)[0] == 0); F Inv; Inv.reserve(deg); Inv.push_back(T(0)); Inv.push_back(T(1)); auto inplace_integral = [&](F &f) -> void { const int n = (int)f.size(); int mod = T::mod(); while (Inv.size() <= n) { int i = Inv.size(); Inv.push_back((-Inv[mod % i]) * (mod / i)); } f.insert(begin(f), T(0)); for (int i = 1; i <= n; i++) f[i] *= Inv[i]; }; auto inplace_diff = [](F &f) -> void { if (f.empty()) return; f.erase(begin(f)); T coeff = 1, one = 1; for (int i = 0; i < f.size(); i++) { f[i] *= coeff; coeff++; } }; F b{1, 1 < (int)(*this).size() ? (*this)[1] : 0}, c{1}, z1, z2{1, 1}; for (int m = 2; m <= deg; m <<= 1) { auto y = b; y.resize(2 * m); internal::butterfly(y); z1 = z2; F z(m); for (int i = 0; i < m; i++) z[i] = y[i] * z1[i]; internal::butterfly_inv(z); T si = T(m).inv(); for (int i = 0; i < m; i++) z[i] *= si; fill(begin(z), begin(z) + m / 2, T(0)); internal::butterfly(z); for (int i = 0; i < m; i++) z[i] *= -z1[i]; internal::butterfly_inv(z); for (int i = 0; i < m; i++) z[i] *= si; c.insert(end(c), begin(z) + m / 2, end(z)); z2 = c; z2.resize(2 * m); internal::butterfly(z2); F x(begin((*this)), begin((*this)) + min((*this).size(), m)); x.resize(m); inplace_diff(x); x.push_back(T(0)); internal::butterfly(x); for (int i = 0; i < m; i++) x[i] *= y[i]; internal::butterfly_inv(x); for (int i = 0; i < m; i++) x[i] *= si; x -= b.diff(); x.resize(2 * m); for (int i = 0; i < m - 1; i++) x[m + i] = x[i], x[i] = T(0); internal::butterfly(x); for (int i = 0; i < 2 * m; i++) x[i] *= z2[i]; internal::butterfly_inv(x); T si2 = T(m << 1).inv(); for (int i = 0; i < 2 * m; i++) x[i] *= si2; x.pop_back(); inplace_integral(x); for (int i = m; i < min((*this).size(), 2 * m); i++) x[i] += (*this)[i]; fill(begin(x), begin(x) + m, T(0)); internal::butterfly(x); for (int i = 0; i < 2 * m; i++) x[i] *= y[i]; internal::butterfly_inv(x); for (int i = 0; i < 2 * m; i++) x[i] *= si2; b.insert(end(b), begin(x) + m, end(x)); } return b.pre(deg); } F pow(ll m) { int n = (*this).size(); if (m == 0) { F ret(n); ret[0] = 1; return ret; } int x = 0; while (x < n && (*this)[x] == T(0)) x++; if (x >= (n + m - 1) / m) { F ret(n); return ret; } F f(n - x); T y = (*this)[x]; for (int i = x; i < n; i++) f[i - x] = (*this)[i] / y; f = f.log(); for (int i = 0; i < n - x; i++) f[i] *= m; f = f.exp(); y = y.pow(m); for (int i = 0; i < n - x; i++) f[i] *= y; F ret(n); const ll xm = x * m; for (int i = xm; i < n; i++) ret[i] = f[i - xm]; return ret; } F shift(T c) { int n = (*this).size(); int mod = T::mod(); vector inv(n + 1); inv[1] = 1; for (int i = 2; i <= n; i++) inv[i] = mod - inv[mod % i] * (mod / i); T x = 1; for (int i = 0; i < n; i++) { (*this)[i] *= x; x *= (i + 1); } F g(n); T y = 1; T now = 1; for (int i = 0; i < n; i++) { g[n - i - 1] = now * y; now *= c; y *= inv[i + 1]; } auto tmp = convolution(g, (*this)); T z = 1; for (int i = 0; i < n; i++) { (*this)[i] = tmp[n + i - 1] * z; z *= inv[i + 1]; } return (*this); } }; using fps = FormalPowerSeries; constexpr int INF = 1e6 + 2022; mint fact[INF + 1], finv[INF + 1]; void solve() { vec fact(INF + 1), finv(INF + 1); fact[0] = 1; for (int i = 1; i <= INF; i++) fact[i] = i * fact[i - 1]; finv[INF] = fact[INF].inv(); for (int i = INF; i > 0; i--) finv[i - 1] = i * finv[i]; int n; mint k; cin >> n >> k; k--; vec kpow(n + 1); kpow[0] = 1; for (int i = 1; i <= n; i++) kpow[i] = k * kpow[i - 1]; vec c(n + 1); for (int i = 0; i <= n; i++) c[i] = kpow[i] + (i % 2 ? -k : k); assert(c[1].val() == 0); fps f0(n + 1), f1(n + 1), g0(n + 1); for (int i = 1; i <= n; i++) { f0[i] = c[i] * fact[i - 1] * finv[i]; f1[i] = c[i - 1]; } for (int i = 0; i <= n; i++) g0[i] = mint(n).pow(998244353LL - 1 + n - 1 - i) * i * finv[n - i]; f0 = f0.exp(); f1 *= f0; f0[0] = 0; mint ans = 0; rep(s, n) ans += f0[s] * g0[s] * kpow[n - s - 1] * (n - s); rep(s, n + 1) ans += f1[s] * g0[s] * kpow[n - s]; cout << ans * fact[n] << endl; } int main() { // srand((unsigned)time(NULL)); cin.tie(nullptr); ios::sync_with_stdio(false); cout << fixed << setprecision(20); int t = 1; // cin >> t; while (t--) solve(); }