def solve(n, k, mod=998244353):
    """Return the total of the second-largest entry over all length-n sequences.

    The sum ranges over all ``k**n`` sequences of length ``n`` whose entries
    lie in ``1..k``. "Second-largest" counts multiplicity: it is the
    second element of the sequence sorted in descending order, so
    ``(3, 3)`` contributes 3. Sequences with fewer than two elements
    contribute 0. The result is reduced modulo ``mod``.

    Args:
        n: Sequence length.
        k: Largest allowed entry value (entries are 1..k).
        mod: Modulus for the result; defaults to 998244353.

    Returns:
        The sum described above, as an int in ``range(mod)``.

    Method: by the layer-cake identity, the sum equals
    ``sum_{i=1..k} #(sequences whose second-largest is >= i)``. A sequence
    has second-largest >= i iff at least two entries are >= i, which is

        k**n                               (all sequences)
      - (i-1)**n                           (no entry >= i)
      - (k-i+1) * n * (i-1)**(n-1)         (exactly one entry >= i)
    """
    if n < 2:
        # No second element exists. This also avoids pow(0, -1, mod)
        # raising ValueError when n == 0.
        return 0

    total_sequences = pow(k, n, mod)
    total = 0
    for i in range(1, k + 1):
        below = i - 1  # number of values strictly less than i
        none_at_or_above = pow(below, n, mod)
        # choose the position (n ways) and its value >= i (k - below ways);
        # every other position takes one of the `below` smaller values
        exactly_one = (k - below) * n % mod * pow(below, n - 1, mod)
        total = (total + total_sequences - none_at_or_above - exactly_one) % mod
    return total


def main():
    """Read ``n k`` from stdin and print ``solve(n, k)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()