mod = 998244353


def solve(n: int, k: int, modulus: int = mod) -> int:
    """Sum, over all length-*n* sequences with entries in 1..k, of the
    second-largest entry (counted with multiplicity), modulo *modulus*.

    For each threshold x in 1..k the loop counts the sequences having at
    least two entries >= x:

        k^n                              all sequences
      - (x-1)^n                          every entry < x
      - n * (x-1)^(n-1) * (k-x+1)        exactly one entry >= x

    A sequence whose second-largest entry is v is counted once for each
    x in 1..v, so summing these counts over x yields the requested sum.

    Args:
        n: Sequence length; must be non-negative.
        k: Largest allowed entry value (entries range over 1..k).
        modulus: Modulus applied to the result (default 998244353).

    Returns:
        The sum reduced into range [0, modulus).

    Raises:
        ValueError: If n is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    if n == 0:
        # Only the empty sequence exists and it has no second-largest entry.
        # Returning early also avoids pow(0, -1, modulus), which raises.
        return 0

    all_sequences = pow(k, n, modulus)  # loop-invariant, computed once
    ans = 0
    for x in range(1, k + 1):
        below = x - 1           # number of values < x
        at_least = k - below    # number of values >= x
        term = all_sequences - pow(below, n, modulus)
        term -= n * pow(below, n - 1, modulus) * at_least
        # Python's % always yields a non-negative result for positive modulus,
        # so intermediate negative terms are normalized correctly.
        ans = (ans + term) % modulus
    return ans


if __name__ == "__main__":
    n, k = map(int, input().split())
    print(solve(n, k))