mod = 998244353


def solve(n, m, modulus=mod):
    """Return ``n * (m+1) * ((m-1) * m**n - 2*S) / 2`` modulo ``modulus``.

    Here ``S = sum(j**n for j in 1..m-1)``. The division by 2 is done with
    the modular inverse of 2 (Fermat), so ``modulus`` must be an odd prime.

    This is an algebraically equivalent, single-pass form of the original
    script's computation:

        ans = (m-1) * m**n * (m+1)
        ans -= sum_{i=1}^{m-1}   i**n         * (i+1)
        ans -= sum_{i=m}^{2}     (m+1-i)**n   * (m+i)
        result = ans * n / 2

    Substituting j = m+1-i turns the second sum into
    sum_{j=1}^{m-1} j**n * (2m+1-j). Added to the first sum, this gives
    2*(m+1) * S, so each power j**n only has to be computed once.

    Args:
        n: Exponent read from input (also a final multiplier).
        m: Upper bound read from input.
        modulus: Odd prime modulus; defaults to 998244353.

    Returns:
        The value above reduced into ``range(modulus)``.
    """
    # Modular inverse of 2 via Fermat's little theorem (requires prime modulus).
    inv2 = pow(2, modulus - 2, modulus)
    # S = 1**n + 2**n + ... + (m-1)**n  (empty, i.e. 0, when m <= 1).
    power_sum = sum(pow(j, n, modulus) for j in range(1, m)) % modulus
    total = (m + 1) * ((m - 1) * pow(m, n, modulus) - 2 * power_sum)
    # Python's % yields a non-negative result even when ``total`` is negative.
    return total % modulus * n % modulus * inv2 % modulus


def main():
    """Read ``n m`` from one line of stdin and print ``solve(n, m)``."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()