MOD = 998244353 n,m = map(int,input().split()) nn = n%MOD if n > 2*MOD: n = n%(MOD-1)+MOD-1 def f(x): return x*(x+1)//2%MOD*pow(x,n-1,MOD)%MOD*nn%MOD def g(x): return (m+x)*(m-x+1)//2%MOD*pow(m-x+1,n-1,MOD)%MOD*nn%MOD ans = 0 for x in range(1,m+1): ans += x*(f(x)-f(x-1)-g(x)+g(x+1)) ans %= MOD print(ans)