MOD = 998244353


def count_tilings(n, m, mod=MOD):
    """Count ordered sequences of parts of size 1 and size ``n`` summing to ``m``.

    The count is taken modulo ``mod``. It is computed with a forward ("push")
    DP where ``dp[i]`` is the number of sequences whose parts sum to exactly
    ``i``. Each reachable prefix sum ``i`` extends by a size-1 part to
    ``i + 1``, and by a size-``n`` part to ``i + n`` when that stays within
    ``m``.

    Args:
        n: Size of the second kind of part. Assumed to be >= 1 (n == 0 would
            make the DP add ``dp[i]`` to itself) — TODO confirm input bounds.
        m: Target total. ``m == 0`` yields 1 (the empty sequence).
        mod: Modulus applied after every addition.

    Returns:
        ``dp[m]`` modulo ``mod``, or 1 when ``n == 1``.
    """
    if n == 1:
        # Kept from the original script: when n == 1 both part sizes coincide,
        # and the answer is defined as 1. Running the DP instead would count
        # the two size-1 kinds separately and give 2**m.
        return 1

    dp = [0] * (m + 1)
    dp[0] = 1
    for i in range(m):
        # Append a part of size 1.
        dp[i + 1] = (dp[i + 1] + dp[i]) % mod
        # Append a part of size n, if it does not overshoot the target.
        if i + n <= m:
            dp[i + n] = (dp[i + n] + dp[i]) % mod
    return dp[m]


def main():
    """Read ``n m`` from one line of stdin and print the tiling count."""
    n, m = map(int, input().split())
    print(count_tilings(n, m))


if __name__ == "__main__":
    main()