n,k = map(int,input().split()) mod = 998244353 ans = 0 last = 0 for i in range(1,k+1)[::-1]: ans += i*n*(k-i)*(pow(i,n-1,mod)-pow(i-1,n-1,mod))%mod # print(ans,i) ans += i*(pow(i,n,mod)-pow(i-1,n,mod)-n*pow(i-1,n-1,mod))%mod # print(ans,i) ans %= mod print(ans)