# Modulus applied to every table entry below.
MOD = 998244353

# Largest index covered by the factorial tables.
fact_range = 10**6

# fact[i] holds i! reduced modulo MOD; each entry is built from its predecessor.
fact = [1] * (fact_range + 1)
for idx in range(1, fact_range + 1):
    fact[idx] = fact[idx - 1] * idx % MOD

# ifact[i] holds the modular inverse of fact[i]. Only the last entry is
# obtained by exponentiation (fact[-1] ** (MOD - 2)); the rest follow by
# walking downwards, since 1/(i-1)! == (1/i!) * i.
ifact = [1] * (fact_range + 1)
ifact[-1] = pow(fact[-1], MOD - 2, MOD)
for idx in reversed(range(1, fact_range + 1)):
    ifact[idx - 1] = ifact[idx] * idx % MOD

def comb(n, k):
    """Return the binomial coefficient C(n, k) reduced modulo MOD.

    The result is 0 when k is negative or greater than n. Otherwise the
    value is assembled from the module-level `fact` and `ifact` tables, so
    n must be a valid index into those tables.
    """
    if 0 <= k <= n:
        # 1 / ((n-k)! * k!) modulo MOD, from the inverse-factorial table.
        inv_denominator = ifact[n - k] * ifact[k] % MOD
        return fact[n] * inv_denominator % MOD
    return 0

# Read the two integers n and m from one whitespace-separated input line.
n, m = map(int, input().split())

# x is (m + 1) / 2 modulo MOD; the division by two is done by multiplying
# with pow(2, MOD - 2, MOD), which relies on MOD being prime.
x = (1 + m) * pow(2, MOD - 2, MOD) % MOD
# Every term of the sum is multiplied by x * n, so reduce that factor once.
scale = x * n % MOD

# Accumulates, modulo MOD,
#   sum over d in [1, m) of d * (m - d) * ((d+1)^n - 2*d^n + (d-1)^n) * scale.
# The d == 0 term is skipped because its leading factor d makes it zero.
ans = 0

# Sliding window over k^n mod MOD: `below` is (d-1)^n and `here` is d^n, so
# each pass needs only one new modular exponentiation instead of three.
below = pow(0, n, MOD)
here = pow(1, n, MOD)
for d in range(1, m):
    above = pow(d + 1, n, MOD)
    # May be negative; Python's % maps it back into [0, MOD).
    second_diff = above - 2 * here + below
    ans = (ans + d * (m - d) % MOD * second_diff % MOD * scale) % MOD
    below, here = here, above

print(ans)