// NOTE(review): this file appears to have been damaged by text extraction: the
// #include directives below carry no header names, `vector g,h;` / `vector hg`
// have no template arguments, and the text stops in the middle of the last
// `for` header of main. Presumably the containers are vector<mint> and the
// headers include the AtCoder Library (modint, convolution) -- confirm against
// the original submission before building.
//
// Globals:
//   mint  -- alias for atcoder's modint998244353.
//   fac[] -- factorial table; main fills fac[0..200000].
//   g, h  -- the two sequences that main passes to convolution().
//
// pw(a, x): square-and-multiply power. Returns a^x for x >= 0 (a^0 == 1, also
//   for a == 0). The single exponent x == -1 is special-cased and returns 1/a;
//   no other negative exponent is treated as an inverse. main relies on the
//   special case through pw(k + 1, k - 1) at k == 0.
//
// f(n, m, k): with e = n - (k - 1):
//   - returns 0 when e < 0;
//   - when m == k, returns 1 if e == 0 and 0 otherwise;
//   - otherwise returns (m - k)^(2*e), computed as pw(m - k, 2*n - 2*(k - 1)).
//
// main (only the visible part is described):
//   - reads n and m from standard input;
//   - fills fac[i] = i! for i in 0..200000;
//   - initialises ans to m^(2*n + 1);
//   - sets h[j] = 1 / fac[j] for j in 0..200000;
//   - sets g[k] = f(n, m, k + 1) * (k + 1)^(k - 1) * 2^k
//                 * fac[n] / (fac[n - k] * fac[k])   for k in 0..n;
//   - computes hg = convolution(h, g).
//   How hg and ans are combined and what is printed lies beyond the visible text.
//   Assumes n <= 200000 (fac has 200010 entries, g is resized to 200001, and
//   both are indexed up to n) -- TODO confirm against the problem constraints.
#include #include #include #include using namespace std; using namespace atcoder; typedef long long ll; using mint = modint998244353; vector g,h; mint fac[200010]; mint pw(mint a,ll x){ mint ret = 1; if(x==-1) return (mint)1/a; while(x){ if(x&1) ret *= a; a *= a; x /= 2; } return ret; } mint f(int n,int m,int k){ if(n - (k - 1)<0) return 0; if(m==k){ if((n - (k - 1))==0) return 1; return 0; } return pw(m - k,2*n - 2*(k - 1)); } int main(){ int i,j,k,n,m; cin >> n >> m; fac[0] = 1; for(i=1;i<=200000;i++) fac[i] = fac[i - 1]*(mint)i; mint ans = pw(m,2*n + 1); g.resize(200001); h.resize(200001); for(j=0;j<=200000;j++) h[j] = (mint)1/fac[j]; for(k=0;k<=n;k++) g[k] = f(n,m,k + 1)*pw(k + 1,k - 1)*pw(2,k)*fac[n]/(fac[n - k]*fac[k]); vector hg = convolution(h,g); for(i=0;i