MOD = 998244353


def count_ways(n: int, m: int, mod: int = MOD) -> int:
    """Return the DP count for (n, m), reduced modulo *mod*.

    For n >= 2 this is the number of ordered sequences of parts, each
    equal to 1 or n, that sum to m:

        ways[0] = 1
        ways[i] = ways[i - 1]                  for i < n
        ways[i] = ways[i - 1] + ways[i - n]    for i >= n

    For example, n = 2 gives the Fibonacci numbers.

    Args:
        n: size of the larger part.
        m: target total; expected to be >= 0.
        mod: modulus applied to each DP entry (default 998244353).

    Returns:
        ways[m] modulo *mod*. Returns 1 when n == 1.
    """
    if n == 1:
        # NOTE(review): the original script special-cases n == 1 to 1,
        # presumably because parts of size 1 and n coincide and are not
        # distinguished. The plain recurrence would give 2**m here.
        # Confirm against the problem statement.
        return 1
    ways = [0] * (m + 1)
    ways[0] = 1
    for i in range(1, m + 1):
        total = ways[i - 1]  # last part has size 1
        if i >= n:
            total += ways[i - n]  # last part has size n
        ways[i] = total % mod
    return ways[m]


def main() -> None:
    """Read "n m" from one line of stdin and print count_ways(n, m)."""
    n, m = map(int, input().split())
    print(count_ways(n, m))


if __name__ == "__main__":
    main()