mod = 998244353


def solve(n, k, modulus=mod):
    """Return, modulo *modulus*, the sum over all length-n sequences with
    entries in 1..k of the sequence's second-largest entry (with multiplicity).

    For each threshold x in 1..k the loop adds the number of sequences that
    have at least two entries >= x:

        (# with at least one entry >= x) - (# with exactly one entry >= x)
        = (k^n - (x-1)^n) - n * (k-x+1) * (x-1)^(n-1)

    Summing that indicator over x yields the second-largest value of each
    sequence, hence the total.

    Assumes n >= 1 (for n == 0, ``pow(0, -1, modulus)`` raises ValueError).
    """
    # k^n does not depend on x, so compute it once instead of k times.
    all_sequences = pow(k, n, modulus)
    ans = 0
    for x in range(1, k + 1):
        below = x - 1  # number of values strictly less than x
        # Sequences with at least one entry >= x.
        at_least_one = all_sequences - pow(below, n, modulus)
        # Sequences with exactly one entry >= x: pick its position (n ways),
        # its value (k - x + 1 ways), the other n-1 entries are < x.
        # For n == 1, pow(0, 0, modulus) == 1 is the correct count.
        exactly_one = pow(below, n - 1, modulus) * (k - x + 1) * n
        # Python's % is non-negative, so ans stays in [0, modulus).
        ans = (ans + at_least_one - exactly_one) % modulus
    return ans


def main():
    """Read "n k" from stdin and print solve(n, k)."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()