MOD = 998244353


def solve(n: int, k: int) -> int:
    """Return the sum of the second-largest value over all sequences, mod MOD.

    The sequences considered are those of length ``n`` with entries in
    ``1..k``.  For each threshold ``i`` the code counts the sequences whose
    second maximum is at most ``i``:

        p(i) = i**n + n * (k - i) * i**(n - 1)

    i.e. either every entry is <= i, or exactly one position (``n`` choices)
    holds a value > i (``k - i`` choices) while the other ``n - 1`` entries
    are <= i.  ``p(i) - p(i - 1)`` is then the number of sequences whose
    second maximum equals ``i``, and the answer is the sum of
    ``i * (p(i) - p(i - 1))``.

    NOTE(review): the formula presumably expects ``n >= 2`` (a second maximum
    needs two elements); smaller ``n`` is evaluated as-is -- confirm against
    the problem constraints.

    Args:
        n: Sequence length.
        k: Largest value an entry may take.

    Returns:
        The sum described above, reduced into ``[0, MOD)``.
    """
    total = 0
    below = 0  # p(i - 1); p(0) is taken to be 0
    for i in range(1, k + 1):
        at_most = (pow(i, n, MOD) + n * (k - i) * pow(i, n - 1, MOD)) % MOD
        # The difference of two residues may be negative; Python's % still
        # yields a value in [0, MOD).
        total = (total + (at_most - below) * i) % MOD
        below = at_most
    return total


def main() -> None:
    """Read ``N K`` from standard input and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()