"""Sum of (max - min) * sum over all length-n sequences with values in [1, m].

Reads ``n m`` from stdin and prints
    sum_{s in [1, m]^n} (max(s) - min(s)) * sum(s)   (mod 998244353).
"""

MOD = 998244353


def solve(n, m, mod=MOD):
    """Return sum over s in [1, m]^n of (max(s) - min(s)) * sum(s), modulo ``mod``.

    The max and min parts are computed separately with telescoping sums:

    * T(i) = total of sum(s) over all s in [1, i]^n
           = n * i^(n-1) * (1 + ... + i).
      T(i) - T(i-1) is the total of sum(s) over sequences whose maximum is
      exactly i, so sum_i i * (T(i) - T(i-1)) = sum_s max(s) * sum(s).
    * U(i) = total of sum(s) over all s in [i, m]^n
           = n * (m+1-i)^(n-1) * (i + ... + m).
      U(i) - U(i+1) is the total of sum(s) over sequences whose minimum is
      exactly i, giving sum_s min(s) * sum(s).

    Args:
        n: sequence length (assumed >= 1 by the original problem).
        m: upper bound of the value range [1, m].
        mod: modulus for the result; defaults to 998244353.

    Returns:
        The sum reduced into [0, mod).
    """
    ans = 0

    # Add sum_s max(s) * sum(s).
    prev_total = 0
    for i in range(1, m + 1):
        total = i * (i + 1) // 2 * pow(i, n - 1, mod) * n % mod
        # total - prev_total may be negative after reduction; % mod normalizes it.
        ans = (ans + i * (total - prev_total)) % mod
        prev_total = total

    # Subtract sum_s min(s) * sum(s), sweeping lower bounds from m down to 1.
    prev_total = 0
    for i in range(m, 0, -1):
        range_sum = m * (m + 1) // 2 - i * (i - 1) // 2  # i + (i+1) + ... + m
        total = range_sum * pow(m + 1 - i, n - 1, mod) * n % mod
        ans = (ans - i * (total - prev_total)) % mod
        prev_total = total

    return ans


def main():
    """Read ``n m`` from stdin and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()