n, m = map(int, input().split()) res, mod = 0, 998244353 pow_table = [] for i in range(m + 1): pow_table.append(pow(i, n, mod)) for r in range(1, m + 1): res += (pow_table[r] - pow_table[r - 1]) * r res %= mod for l in range(1, m + 1): res -= (pow_table[m - l + 1] - pow_table[m - l]) * l res %= mod res *= n res %= mod res *= m + 1 res %= mod res *= 499122177 res %= mod print(res)