MOD = 998244353


def solve(n, k, mod=MOD):
    """Return the sum of the second-largest entry over all k**n sequences, modulo *mod*.

    The sequences are all length-*n* sequences with entries in 1..*k*. The
    second-largest entry is taken with multiplicity, e.g. (3, 3, 1) -> 3.

    The second-largest entry is >= x exactly when at least two entries are >= x.
    So the number of such sequences is
        k**n - (x-1)**n - n * (k-x+1) * (x-1)**(n-1),
    i.e. all sequences, minus those with no entry >= x, minus those with exactly
    one entry >= x.

    By the tail-sum identity v == sum_{x=1..k} [v >= x] for v in 1..k, the
    requested total equals the sum of those counts over x = 1..k.

    Args:
        n: sequence length. Values below 2 yield 0, since no sequence has a
            second entry.
        k: largest allowed entry. Values below 1 yield 0.
        mod: modulus applied to the result.

    Returns:
        The total, reduced into range(mod).
    """
    if n < 2:
        # Also avoids pow(0, -1, mod) (a ValueError) when n == 0.
        return 0
    total = pow(k, n, mod)  # loop-invariant: all k**n sequences
    ans = 0
    for x in range(1, k + 1):
        below = x - 1  # number of values strictly less than x
        # Ways to fill n-1 positions with values < x.
        fill_rest = pow(below, n - 1, mod)
        # No entry >= x: all n positions are < x.
        none_at_least = fill_rest * below % mod
        # Exactly one entry >= x: pick its position (n ways) and its value
        # (k - x + 1 ways); the remaining n - 1 entries are all < x.
        exactly_one = n * (k - x + 1) % mod * fill_rest % mod
        ans = (ans + total - none_at_least - exactly_one) % mod
    return ans


def main():
    """Read "n k" from stdin and print solve(n, k)."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()