#include <iostream>

using ll = long long;

/// Computes a^n mod `mod` by iterative binary exponentiation.
///
/// @param a    base; assumed non-negative. It is reduced modulo `mod` first so
///             every intermediate product is < mod^2 and cannot overflow
///             `long long` for moduli below ~3e9 (the original left `a`
///             unreduced, which overflowed for large bases).
/// @param n    exponent; must be >= 0.
/// @param mod  positive modulus.
/// @return a^n mod `mod`; returns 1 when n == 0 (same as the original).
ll pow_mod(ll a, ll n, ll mod) {
    ll result = 1;
    a %= mod;
    while (n > 0) {
        if (n & 1) {
            result = result * a % mod;
        }
        a = a * a % mod;
        n >>= 1;
    }
    return result;
}

namespace {

constexpr ll kMod = 998244353;

/// Evaluates, modulo kMod, the closed-form expression this program prints:
///
///   X = M^(N+1) * (M+1)
///   Y = sum_{l=1..M} [ (l-1)^N * l + (M-l+1)^N * (M+l) ]
///   answer = (X - Y) * N * inv(2)
///
/// NOTE(review): the combinatorial meaning of this formula is not visible
/// here. Confirm it against the original problem statement.
///
/// Every operand is reduced modulo kMod before it is multiplied, and Y is
/// reduced on each iteration, so no intermediate overflows `long long`.
/// Runs in O(M log N) time (one pow_mod per loop iteration).
ll solve(ll N, ll M) {
    const ll X = pow_mod(M, N + 1, kMod) * ((M + 1) % kMod) % kMod;

    ll Y = 0;
    // `ll` counter: M is `long long`, so an `int` index could overflow.
    for (ll l = 1; l <= M; ++l) {
        const ll a = pow_mod(l - 1, N, kMod) * (l % kMod) % kMod;
        const ll b = pow_mod(M - l + 1, N, kMod) * ((M + l) % kMod) % kMod;
        Y = (Y + a + b) % kMod;
    }

    // `%` can yield a negative remainder in C++, so shift into [0, kMod).
    const ll Z = ((X - Y) % kMod + kMod) % kMod;

    // Division by 2 is multiplication by the modular inverse of 2
    // (Fermat's little theorem; kMod is prime).
    const ll two_inv = pow_mod(2, kMod - 2, kMod);
    const ll k = (N % kMod) * two_inv % kMod;
    return Z * k % kMod;
}

}  // namespace

/// Reads N and M from stdin and prints solve(N, M) followed by a newline.
int main() {
    ll N = 0;
    ll M = 0;
    std::cin >> N >> M;
    std::cout << solve(N, M) << '\n';
    return 0;
}