n, m = map(int, input().split()) res, mod = 0, 998244353 pow_t = [] for i in range(m + 1): pow_t.append(pow(i, n, mod)) for d in range(1, m): res += (pow_t[d + 1] - pow_t[d] * 2 + pow_t[d - 1]) % mod * d * (m - d) res %= mod res *= n res %= mod res *= m + 1 res %= mod res *= 499122177 res %= mod print(res)