MOD = 998244353


def solve(c, x):
    """Return the script's answer for inputs ``c`` and ``x``, modulo MOD.

    Division is done with modular inverses ``pow(v, MOD - 2, MOD)``.
    NOTE(review): the problem statement is not visible here. The DP below
    looks like a probability distribution over a counter ``j``; confirm
    against the task.

    Args:
        c: First input integer; the DP steps run for ``i = c+1 .. x``.
        x: Second input integer; the DP tracks states ``j = 0 .. x``.

    Returns:
        An integer in ``[0, MOD)``.
    """
    if x == 0:
        # Special case taken directly from the original script: 1 / (c + 1).
        return pow(c + 1, MOD - 2, MOD)

    dp = [0] * (x + 1)
    dp[0] = 1
    for i in range(c + 1, x + 1):
        inv_i = pow(i, MOD - 2, MOD)
        ndp = [0] * (x + 1)
        # Index x is not propagated. Each step advances j by at most one, so
        # dp[x] can only become non-zero after the final step.
        for j in range(x):
            weight = dp[j]
            stay = j * inv_i % MOD          # j / i: remain in state j
            move = (i - j) * inv_i % MOD    # (i - j) / i: advance to j + 1
            ndp[j] = (ndp[j] + weight * stay) % MOD
            ndp[j + 1] = (ndp[j + 1] + weight * move) % MOD
        dp = ndp

    ans = 0
    for k, weight in enumerate(dp):
        # x + 1 - k >= 1 for k in 0..x, so the inverse is always defined.
        ans = (ans + weight * (1 - pow(x + 1 - k, MOD - 2, MOD))) % MOD
    return ans


def main():
    """Read ``c x`` from one line of stdin and print ``solve(c, x)``."""
    c, x = map(int, input().split())
    print(solve(c, x))


if __name__ == "__main__":
    main()