#include #include #include "atcoder/modint.hpp" #include "atcoder/convolution.hpp" using namespace std; const int mod = 998244353; using mint = atcoder::modint998244353; const int MX = 1010101; mint fact[MX], finv[MX], inv[MX]; mint com(int n, int k) { if(n < 0 or k < 0 or n < k) return 0; return fact[n] * finv[k] * finv[n-k]; } void init() { fact[0] = fact[1] = 1; finv[0] = finv[1] = 1; inv[1] = 1; for (int i = 2; i < MX; i++) { fact[i] = fact[i-1] * i; inv[i] = mod - inv[mod%i] * (mod/i); finv[i] = finv[i-1] * inv[i]; } } int main() { int n; long long k; cin >> n >> k; if (k == 1){ cout << (n == 1 ? 1 : 0) << "\n"; return 0; } init(); vector kp(n+1), np(n+1), f(n+1); mint tk = 1, tn = 1; mint ni = mint(n).inv(); mint coef = 1; for (int i = 0; i <= n; i++) { kp[i] = tk; tk *= k-1; np[i] = tn; tn *= ni; f[i] = coef * (i%2 ? -1:1); coef *= inv[i+1]; coef *= k-1+i; } auto g = atcoder::convolution(kp,f); // tree mint ans = 0; for (int s = 0; s < n; s++) ans += g[s] * s * kp[n-s-1] * finv[n-s-1] * np[s]; // cycle vector a(n+1); for (int i = 0; i <= n; i++) a[i] = kp[i] + (k-1) * (i%2 ? -1:1); auto h = atcoder::convolution(a,g); for (int s = 1; s <= n; s++) ans += h[s-1] * s * kp[n-s] * finv[n-s] * np[s]; ans *= fact[n] * mint(n).pow(n-1); cout << ans.val() << "\n"; }