"""Read ``n`` and ``m`` from stdin and print a value modulo 998244353.

The core quantity is the sum, over every length-``n`` sequence with entries
in ``1..m``, of ``max - min``.  It is then multiplied by ``n * (m + 1) / 2``.
The problem statement that motivates this scaling factor is not part of this
file.
"""

MOD = 998244353  # prime, so Fermat's little theorem yields modular inverses


def solve(n: int, m: int) -> int:
    """Return ``S * n * (m + 1) / 2`` modulo ``MOD``.

    ``S = (m - 1) * m**n - 2 * sum(i**n for i in 1..m-1)``.

    For ``n >= 1``, ``S`` is the sum of ``max - min`` over all ``m**n``
    sequences of length ``n`` with entries in ``1..m``:

    * the sum of maxima is ``m**(n+1) - sum_{j=0}^{m-1} j**n``;
    * the sum of minima is ``sum_{j=1}^{m} j**n``;
    * subtracting them gives ``S``, because ``0**n == 0``.

    Args:
        n: sequence length (exponent).
        m: largest allowed value; values range over ``1..m``.

    Returns:
        The scaled total, reduced into ``[0, MOD)``.
    """
    power_sum = sum(pow(i, n, MOD) for i in range(1, m))
    spread_total = (m - 1) * pow(m, n, MOD) - 2 * power_sum
    # MOD is prime, so 2**(MOD-2) is the modular inverse of 2.
    inv2 = pow(2, MOD - 2, MOD)
    # spread_total may be negative; Python's % always lands in [0, MOD).
    return spread_total * n * (m + 1) * inv2 % MOD


def main() -> None:
    """Read ``n m`` from one line of standard input and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()