#include <stdio.h>

/* Modulus applied to every intermediate and final result in this file. */
const long long mod = 998244353;

/*
 * Modular exponentiation by repeated squaring.
 *
 * a: base. It is reduced modulo `mod` once on entry so that the products
 *    below always fit in a long long, even when the caller passes a base
 *    that is not already reduced.
 * b: exponent; callers in this file pass n - 1 and expect a non-negative
 *    value.
 *
 * Returns a^b reduced modulo `mod` (1 when b == 0).
 */
long long power(long long a, long long b)
{
    long long res = 1;

    a %= mod;
    while (b) {
        if (b & 1)
            res = res * a % mod;
        a = a * a % mod;
        b /= 2;
    }
    return res;
}

/*
 * Returns (k * (k + 1) / 2) * n reduced modulo `mod`, i.e. the sum
 * 1 + 2 + ... + k (for k >= 0) multiplied by n.
 *
 * main does not call this helper; it is kept because it is an externally
 * visible function.
 */
long long NotBigger(long long k, long long n)
{
    long long digit = (1 + k) * k / 2 % mod;

    return digit * (n % mod) % mod;
}

/*
 * Returns  (l + (l+1) + ... + r) * (r - l + 1)^(n-1) * n  modulo `mod`,
 * or 0 when the interval [l, r] is empty (l > r).
 *
 * Combinatorially this is the sum, over every length-n sequence whose
 * entries all lie in [l, r], of the sum of that sequence's entries: each of
 * the n positions takes each value exactly (r - l + 1)^(n-1) times.
 *
 * l, r: inclusive bounds of the value interval.
 * n:    sequence length.
 */
long long inRange(long long l, long long r, long long n)
{
    if (l > r)
        return 0;

    long long gap = r - l + 1;                   // number of values in [l, r]
    long long digit = (l + r) * gap / 2 % mod;   // l + (l+1) + ... + r
    long long res = digit * power(gap, n - 1) % mod;

    return res * (n % mod) % mod;
}

/*
 * Reads n and m from standard input and prints (pos - neg) mod `mod`, where
 *
 *   pos = sum over i in 1..m of i * (inRange(1, i, n) - inRange(1, i-1, n))
 *   neg = sum over i in 1..m of i * (inRange(i, m, n) - inRange(i+1, m, n))
 *
 * The first difference isolates the sequences over [1, m] whose maximum is
 * exactly i, the second those whose minimum is exactly i, so the printed
 * value is the sum of (max - min) * (sum of entries) over all length-n
 * sequences with entries in [1, m].
 *
 * Returns 0 on success, 1 if the two integers cannot be read.
 */
int main(void)
{
    long long n, m;

    // Without this check n and m would be used uninitialized on bad input.
    if (scanf("%lld%lld", &n, &m) != 2) {
        fprintf(stderr, "expected two integers: n m\n");
        return 1;
    }

    long long pos = 0, neg = 0;

    // The counter is long long so that comparing against m cannot overflow.
    for (long long i = 1; i <= m; i++) {
        // Contribution of sequences whose maximum is exactly i.
        long long way = inRange(1, i, n) - inRange(1, i - 1, n);
        if (way < 0)
            way += mod;
        pos = (pos + way * i % mod) % mod;

        // Contribution of sequences whose minimum is exactly i.
        way = inRange(i, m, n) - inRange(i + 1, m, n);
        if (way < 0)
            way += mod;
        neg = (neg + way * i % mod) % mod;
    }

    long long ans = pos - neg;
    if (ans < 0)
        ans += mod;

    printf("%lld\n", ans);
    return 0;
}