n,k=map(int,input().split())
ans=0
mod=998244353
for i in range(1,k+1):
  tmp=(k-i)*n*(pow(i,n-1,mod)-pow(i-1,n-1,mod)+mod)
  tmp+=pow(i,n,mod)-pow(i-1,n,mod)-n*pow(i-1,n-1,mod)
  ans+=tmp*i
  ans%=mod
ans=(ans+mod)%mod
print(ans)