MOD = 998244353


def solve(n, k, mod=MOD):
    """Return the sum of the second-largest entry over all length-n sequences
    with entries in 1..k, reduced modulo ``mod``.

    For each threshold x in 1..k, the number of sequences whose second-largest
    entry is >= x equals the number with at least two entries >= x:

        k^n                          (all sequences)
      - (x-1)^n                      (no entry >= x)
      - n * (k-x+1) * (x-1)^(n-1)    (exactly one entry >= x)

    Summing these counts over x gives the sum of the second-largest values
    (tail-sum identity). This is the telescoped form of
    ``sum x * (f[x] - f[x+1])`` with ``f[k+1] == 0``.

    Args:
        n: sequence length.
        k: largest allowed entry; entries range over 1..k.
        mod: modulus for the result (defaults to 998244353).

    Returns:
        The sum modulo ``mod``. It is 0 when ``n < 2``, because then there is
        no second-largest entry, or when ``k < 1``, because the loop is empty.
    """
    if n < 2:
        # Guard: with n == 0, pow(0, n - 1, mod) would raise ValueError.
        return 0
    total = pow(k, n, mod)  # independent of x, so computed once
    ans = 0
    for x in range(1, k + 1):
        below = x - 1  # count of values strictly smaller than x
        none_at_or_above = pow(below, n, mod)
        exactly_one = n % mod * ((k - below) % mod) % mod * pow(below, n - 1, mod)
        ans = (ans + total - none_at_or_above - exactly_one) % mod
    return ans


def main():
    """Read "n k" from stdin and print solve(n, k)."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()