def count_ways(n, m, mod=998244353):
    """Count ordered step sequences of sizes 1 and ``n`` that sum to ``m``.

    ``dp[k]`` holds the number of ways to reach total ``k``. Each total ``i``
    below ``m`` is extended by a step of 1 and, when the result stays within
    ``m``, by a step of ``n``. Every partial count is reduced modulo ``mod``.

    Args:
        n: Size of the long step. Presumably ``n >= 1`` -- TODO confirm
            against the problem constraints. With ``n == 1`` the two step
            kinds are counted as distinct, so the result is ``2**m``.
        m: Target total; assumed ``m >= 0``.
        mod: Modulus applied to every partial count.

    Returns:
        The number of ways to reach exactly ``m``, modulo ``mod``.
    """
    dp = [0] * (m + 1)
    dp[0] = 1
    for i in range(m):
        # dp[i] is final here: all contributions to it come from smaller indices.
        ways = dp[i]
        dp[i + 1] = (dp[i + 1] + ways) % mod
        if i + n <= m:
            dp[i + n] = (dp[i + n] + ways) % mod
    return dp[m]


def main():
    """Read ``N M`` from stdin and print the count modulo 998244353."""
    n, m = map(int, input().split())
    print(count_ways(n, m))


if __name__ == "__main__":
    main()