"""Read ``n`` and ``m`` from stdin and print the scaled max-min sum mod 998244353.

Over all ``m**n`` arrays of length ``n`` with entries in ``1..m``, the sum of
``max(a) - min(a)`` is accumulated, then multiplied by ``n * (m + 1) / 2``
(division done via the modular inverse of 2).
"""

MOD = 998244353
INV2 = (MOD + 1) // 2  # modular inverse of 2 for an odd modulus (== 499122177)


def solve(n, m):
    """Return ``n * (m + 1) / 2 * sum(max(a) - min(a))`` modulo ``MOD``.

    The sum runs over every length-``n`` array ``a`` with entries in ``1..m``.

    Args:
        n: Array length (non-negative int).
        m: Largest allowed value (non-negative int); ``m == 0`` yields 0.

    Returns:
        The result reduced into ``range(MOD)``.
    """
    # d_j = j^n - (j-1)^n counts arrays whose max is exactly j, and equally
    # (by reindexing j = m - l + 1) arrays whose min is exactly l = m + 1 - j.
    # So sum(max) - sum(min) = sum_j d_j * (j - (m + 1 - j)) = sum_j d_j * (2j - m - 1),
    # which needs only one modular power per j instead of four.
    total = 0
    prev = pow(0, n, MOD)  # (j-1)^n for j == 1; note 0**0 == 1 when n == 0
    for j in range(1, m + 1):
        cur = pow(j, n, MOD)
        total = (total + (cur - prev) * (2 * j - m - 1)) % MOD
        prev = cur
    # Python's % is non-negative for a positive modulus, so total is in [0, MOD).
    return total * (n % MOD) % MOD * ((m + 1) % MOD) % MOD * INV2 % MOD


def main():
    """Read ``n m`` from one stdin line and print ``solve(n, m)``."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()