N,K = map(int,input().split()) P = 998244353 ans = 0 dat = [0] * (K + 1) for i in range(K): dat[K - i] = pow(K,N,P) - (K - i - 1 + (i + 1) * N) * pow(K - i -1,N - 1,P) S = 0 """ for i in range(K,0,-1): dat[i] -= S dat[i] %= P S += dat[i] S %= P ans += dat[i] * i ans %= P """ for i in range(1,K): dat[i] -= dat[i+1] ans += dat[i] * i ans %= P ans += dat[K] * K ans %= P print(ans)