n,k=map(int,input().split()) ans=0 mod=998244353 for i in range(1,k+1): tmp=(k-i)*n*(pow(i,n-1,mod)-pow(i-1,n-1,mod)+mod) tmp+=pow(i,n,mod)-pow(i-1,n,mod)-n*pow(i-1,n-1,mod) ans+=tmp*i ans%=mod ans=(ans+mod)%mod print(ans)