MOD = 998244353


def solve(n, k, mod=MOD):
    """Return the sum of the second-largest entry over all sequences, mod ``mod``.

    The sum runs over all length-``n`` sequences with entries in ``1..k``.
    "Second-largest" counts multiplicity: for (2, 2, 1) it is 2. Sequences
    with fewer than two entries contribute nothing.

    For each candidate value ``x`` of the second-largest entry there are two
    disjoint cases:

    * The maximum is unique. One of the ``n`` positions holds some
      ``y`` in ``(x, k]`` (``k - x`` choices). The other ``n - 1`` entries
      have maximum exactly ``x``: ``x^(n-1) - (x-1)^(n-1)`` ways.
    * The maximum ``x`` occurs at least twice. That is all sequences with
      maximum exactly ``x`` (``x^n - (x-1)^n``), minus those where ``x``
      occurs exactly once (``n * (x-1)^(n-1)``).

    Args:
        n: Sequence length.
        k: Largest allowed entry value.
        mod: Modulus for the result (defaults to 998244353).

    Returns:
        The sum, reduced into ``[0, mod)``.
    """
    if n < 2:
        # No second-largest entry exists, so every sequence contributes 0.
        return 0
    total = 0
    # Holds (x-1)^(n-1) for the current x. Reusing the previous iteration's
    # x^(n-1) means one modular exponentiation per x instead of five.
    prev = pow(0, n - 1, mod)
    for x in range(1, k + 1):
        cur = pow(x, n - 1, mod)  # x^(n-1)
        # Sequences of length n-1 whose maximum is exactly x.
        exact_max = (cur - prev) % mod
        # Unique-max case: n positions x (k - x) larger values x exact_max.
        total += x * (k - x) % mod * n % mod * exact_max
        # Max x appears at least twice: x^n - n*(x-1)^(n-1) - (x-1)^n.
        at_least_two = (cur * x - n * prev - prev * (x - 1)) % mod
        total += x * at_least_two
        total %= mod
        prev = cur
    return total % mod


def main():
    """Read ``n k`` from one stdin line and print ``solve(n, k)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()