// NOTE(review): This file is damaged and will not compile as shown.
//   * Newlines appear to have been collapsed: the whole program sits on three
//     physical lines. The first line is therefore one giant `#include`
//     directive. The embedded `//` comments (`// 1, 1/4, 1/8, ...` and
//     `//if tl constexpr...`) would also swallow the code that follows them
//     on the same line.
//   * Text from a `<` immediately followed by a letter up to the next `>`
//     appears to have been stripped, consistent with an HTML tag sanitizer.
//     Examples: `#include` with no header name, `template struct Fft` with no
//     parameter list, bare `vector` / `static_cast(...)`, and fragments such
//     as `while(v.size()n)`, `for(int i=1;i0)` and `for(int i=1;i>n>>m;`.
//     Parts of several bodies are lost this way: form, deriv, integ, mul,
//     inv1, inv, operator/, and the fact/invf setup in main.
//   Restore this file from the original source before changing any logic.
//   The comments below describe only what is still visible.
//
// Overview of the visible pieces on the next line:
//   * `#define int long long`: every `int` below is 64-bit. `main` is
//     therefore declared `int32_t`.
//   * p = 998244353 is the modulus used by the free functions.
//   * po(a, b): computes a^b mod p recursively. For even b it squares
//     po(a, b/2); for odd b it multiplies a by po(a, b-1).
//       - `a` itself is never reduced, so po(a, 1) returns `a` unchanged.
//       - The products only stay within 64 bits when 0 <= a < p.
//   * inv(x): po(x, p - 2), a Fermat inverse. This relies on p being prime
//     and x not being a multiple of p.
//   * Fft: a number-theoretic transform modulo M.
//       - Constructor: sets g[0] = 1 and g[1 << (K-2)] = G. It squares
//         downward to fill g[l >> 1], asserts g[1]^2 == M - 1 (mod M), and
//         then fills g[l + i] = g[l] * g[i].
//       - fft(x): asserts x.size() is a power of two no larger than 1 << K.
//         It runs in-place butterfly passes from the largest stride down,
//         then applies a bit-reversal permutation.
//       - convolution(a, b): returns {} if either input is empty. Otherwise
//         it returns the product with na + nb - 1 coefficients.
#include using namespace std; #define int long long typedef long long ll; const int p=998244353; int po(int a,int b) {if(b==0) return 1; if(b==1) return a; if(b%2==0) {int u=po(a,b/2);return (u*1LL*u)%p;} else {int u=po(a,b-1);return (a*1LL*u)%p;}} int inv(int x) {return po(x,p-2);} template struct Fft { // 1, 1/4, 1/8, 3/8, 1/16, 5/16, 3/16, 7/16, ... int g[1 << (K - 1)]; Fft() : g() { //if tl constexpr... static_assert(K >= 2, "Fft: K >= 2 must hold"); g[0] = 1; g[1 << (K - 2)] = G; for (int l = 1 << (K - 2); l >= 2; l >>= 1) { g[l >> 1] = (static_cast(g[l]) * g[l]) % M; } assert((static_cast(g[1]) * g[1]) % M == M - 1); for (int l = 2; l <= 1 << (K - 2); l <<= 1) { for (int i = 1; i < l; ++i) { g[l + i] = (static_cast(g[l]) * g[i]) % M; } } } void fft(vector &x) const { const int n = x.size(); assert(!(n & (n - 1)) && n <= 1 << K); for (int h = __builtin_ctz(n); h--; ) { const int l = 1 << h; for (int i = 0; i < n >> 1 >> h; ++i) { for (int j = i << 1 << h; j < ((i << 1) + 1) << h; ++j) { const int t = (static_cast(g[i]) * x[j | l]) % M; if ((x[j | l] = x[j] - t) < 0) x[j | l] += M; if ((x[j] += t) >= M) x[j] -= M; } } } for (int i = 0, j = 0; i < n; ++i) { if (i < j) std::swap(x[i], x[j]); for (int l = n; (l >>= 1) && !((j ^= l) & l); ) {} } } vector convolution(const vector &a, const vector &b) const { if(a.empty() || b.empty()) return {}; const int na = a.size(), nb = b.size(); int n, invN = 1; for (n = 1; n < na + nb - 1; n <<= 1) invN = ((invN & 1) ? 
// The next line continues Fft::convolution.
//   - Each doubling of n halves invN modulo the odd M, so after the loop
//     invN == n^{-1} mod M.
//   - Both inputs are zero-padded to n, transformed, multiplied pointwise and
//     scaled by invN.
//   - Reversing x[1..n-1] and applying the forward transform again performs
//     the inverse transform. The result is then cut to na + nb - 1
//     coefficients.
//
// `Fft<998244353,23,31> muls` is the transform instance used by the
// polynomial helpers. Which of 23 and 31 is K and which is G depends on the
// stripped parameter list. Since 998244353 - 1 = 119 * 2^23 and `g` has
// 1 << (K - 1) entries, K = 23 and G = 31 look intended (confirm against the
// original).
//
// Polynomial helpers on coefficient vectors, as far as visible:
//   form(v, n)  -- pops trailing coefficients while v.size() > n. The padding
//                  half is stripped; presumably it zero-extends to exactly n
//                  (confirm).
//   operator*   -- muls.convolution(v1, v2).
//   operator+/- -- coefficient-wise. The visible branches correct results
//                  back toward [0, p).
//   trmi(v)     -- replaces coefficients with their negation mod p, starting
//                  at index 1; the loop bound and step are stripped.
//                  inv() uses it as v(x) * trmi(v), which suggests trmi(v) is
//                  v(-x) (confirm).
//   deriv/integ -- presumably the formal derivative and antiderivative
//                  (bodies partly lost).
//   mul         -- splits a list of polynomials into v1 and v2; presumably a
//                  divide-and-conquer product (body partly lost).
//   inv1(v, n)  -- appears to be a Newton-iteration inverse of v mod x^n,
//                  using a <- a + a * (1 - v * a). Asserts v[0] != 0.
//   inv(v, n)   -- inverse of v mod x^n via the v(x) * v(-x) halving
//                  recursion. Asserts v[0] != 0.
//   operator/   -- polynomial quotient. It strips trailing zeros, multiplies
//                  by inv(b, n-m+1), keeps n-m+1 terms, reverses and strips
//                  zeros. The input reversals are in stripped text.
//   operator%   -- remainder a - b * (a / b), with trailing zeros stripped.
//   fact/invf   -- factorial and inverse-factorial tables of size maxn.
//   c(n, k)     -- binomial coefficient mod p. There is no range check; it
//                  requires 0 <= k <= n < maxn.
(invN + M) : invN) >> 1; vector x(n, 0), y(n, 0); std::copy(a.begin(), a.end(), x.begin()); std::copy(b.begin(), b.end(), y.begin()); fft(x); fft(y); for (int i = 0; i < n; ++i) x[i] = (((static_cast(x[i]) * y[i]) % M) * invN) % M; std::reverse(x.begin() + 1, x.end()); fft(x); x.resize(na + nb - 1); return x; } }; Fft<998244353,23,31> muls; vector form(vector v,int n) { while(v.size()n) v.pop_back(); return v; } vector operator *(vector v1,vector v2) { return muls.convolution(v1,v2); } vector operator +(vector v1,vector v2) { while(v2.size()=p) v1[i]-=p; else if(v1[i]<0) v1[i]+=p;} return v1; } vector operator -(vector v1,vector v2) { int sz=max(v1.size(),v2.size());while(v1.size()=p) v1[i]-=p;} return v1; } vector trmi(vector v) { for(int i=1;i0) v[i]=p-v[i]; else v[i]=(-v[i]);} return v; } vector deriv(vector v) { if(v.empty()) return{}; vector ans(v.size()-1); for(int i=1;i integ(vector v) { vector ans(v.size()+1);ans[0]=0; for(int i=1;i mul(vector > v) { if(v.size()==1) return v[0]; vector > v1,v2;for(int i=0;i inv1(vector v,int n) { assert(v[0]!=0); int sz=1;v=form(v,n);vector a={inv(v[0])}; while(sz vsz;for(int i=0;i b=((vector) {1})-muls.convolution(a,vsz); for(int i=0;i c=muls.convolution(b,a); for(int i=0;i inv(vector v,int n) { v=form(v,n);assert(v[0]!=0);if(v.size()==1) {return {inv(v[0])};} vector v1=trmi(v); vector a=v1*v;a=form(a,2*n); vector b((n+1)/2);for(int i=0;i ans1=inv(b,b.size());vector ans2(n);for(int i=0;i operator/(vector a,vector b) { while(!a.empty() && a.back()==0) a.pop_back(); while(!b.empty() && b.back()==0) b.pop_back(); int n=a.size();int m=b.size();if(n ans=a*inv(b,n-m+1);while(ans.size()>n-m+1) ans.pop_back(); reverse(ans.begin(),ans.end());while(!ans.empty() && ans.back()==0) ans.pop_back();return ans; } vector operator%(vector a,vector b) { vector ans=a-b*(a/b);while(!ans.empty() && ans.back()==0) ans.pop_back(); return ans; } const int maxn=2e5+5; int fact[maxn];int invf[maxn]; int c(int n,int k) { int 
// The next line finishes c(n, k) and then defines po1 and main.
//
// po1(a, b, sz): a^b as a polynomial, computed with the same recursion as
// po(). After each multiplication the result is truncated to sz
// coefficients. For b == 0 it returns {1}; for b == 1 it returns `a`
// untruncated.
//
// main:
//   - The fact/invf initialisation and the read of n and m are in stripped
//     text (presumably `cin >> n >> m`).
//   - If m <= 1, it prints 0.
//   - h[i] = (invf[i])^2 for i = 0..n, and o = h^(m-2) truncated to n + 1
//     terms.
//   - res = m(m-1)/2 * sum_{i=2..n} C(n,i)^2 * o[n-i] * ((n-i)!)^2
//                                * (i-1) * 2^(2i-3)      (mod p)
// NOTE(review): when m == 2, po1(h, 0, n+1) returns the one-element vector
//   {1}. The loop then reads o[n-i] past the end of `o` for every i < n
//   (undefined behaviour once n >= 3). Resizing `o` to n + 1, zero-filled,
//   right after the po1 call would make this well-defined.
// NOTE(review): m*(m-1) is formed before reducing mod p. It overflows 64 bits
//   if m can exceed about 3e9 (confirm against the input limits).
ans=fact[n];ans*=invf[k];ans%=p;ans*=invf[n-k];ans%=p;return ans; } vector po1(vector a,int b,int sz) { if(b==0) return {1}; if(b==1) return a; if(b%2==0) {vector v=po1(a,b/2,sz);v=(v*v);while(v.size()>sz) v.pop_back(); return v;} else {vector v=po1(a,b-1,sz);v=(a*v);while(v.size()>sz) v.pop_back(); return v;} } int32_t main() { ios_base::sync_with_stdio(false);cin.tie(0);cout.tie(0); fact[0]=1;for(int i=1;i>n>>m; if(m<=1) {cout<<0;return 0;} vector h(n+1);for(int i=0;i<=n;++i) {h[i]=(invf[i]*invf[i])%p;} vector o=po1(h,m-2,n+1); int res=0; for(int i=2;i<=n;++i) { res+=(((((((((((c(n,i)*c(n,i))%p)*o[n-i])%p)*fact[n-i])%p)*fact[n-i])%p)*(i-1))%p)*po(2,2*i-3))%p;res%=p; } res*=((m*(m-1)/2)%p); cout<<(res%p+p)%p; return 0; }