MOD = 998244353


def sum_second_largest(n: int, k: int, mod: int = MOD) -> int:
    """Return the sum, over all k**n sequences with entries in 1..k, of the
    second-largest entry (ties counted, e.g. (3, 3) contributes 3), modulo *mod*.

    For a threshold x, let at_least_two(x) be the number of sequences having at
    least two entries >= x, i.e. whose second-largest entry is >= x:

        at_least_two(x) = k^n                        (all sequences)
                        - (x-1)^n                    (no entry >= x)
                        - n * (k-x+1) * (x-1)^(n-1)  (exactly one entry >= x)

    The number of sequences whose second-largest entry equals x is then
    at_least_two(x) - at_least_two(x+1), with at_least_two(k+1) = 0.

    Sequences of length < 2 have no second-largest entry, so the result is 0.
    (The previous script raised ValueError for n == 0 via pow(0, -1, mod).)
    """
    if n < 2:
        return 0

    total = pow(k, n, mod)
    ans = 0
    prev = 0  # at_least_two(k + 1): no entry can exceed k
    for x in range(k, 0, -1):
        low_pow = pow(x - 1, n - 1, mod)        # (x-1)^(n-1)
        none_at_least = low_pow * (x - 1) % mod  # (x-1)^n
        exactly_one = (k - x + 1) * n % mod * low_pow % mod
        at_least_two = (total - none_at_least - exactly_one) % mod
        # Python's % keeps the difference non-negative.
        ans = (ans + x * (at_least_two - prev)) % mod
        prev = at_least_two
    return ans


def main() -> None:
    """Read "N K" from stdin and print the answer."""
    n, k = map(int, input().split())
    print(sum_second_largest(n, k))


if __name__ == "__main__":
    main()