"""Read two integers ``n m`` from stdin and print a count modulo 998244353.

The count follows the recurrence ``dp[i] = dp[i - 1] + dp[i - n]`` (the
second term only once ``i >= n``) with ``dp[0] = 1``, evaluated at ``i = m``.
"""

MOD = 998244353


def solve(n, m):
    """Return the answer for the pair ``(n, m)``, reduced modulo ``MOD``.

    Args:
        n: The larger step size used by the recurrence.
        m: The index at which the recurrence is evaluated.

    Returns:
        ``1`` when ``n == 1`` or ``m < n``; otherwise ``dp[m]`` where
        ``dp[0] = 1`` and ``dp[i] = dp[i - 1] + dp[i - n]`` for ``i >= n``
        (just ``dp[i - 1]`` for ``i < n``), all taken modulo ``MOD``.
    """
    # NOTE(review): for n == 1 the recurrence below would give 2**m, but the
    # original script prints 1 instead. Presumably the underlying problem
    # treats the two step sizes as the same move when n == 1 -- confirm
    # against the problem statement before changing this.
    #
    # For m < n the second term never applies, so every dp value stays 1 and
    # the table does not need to be built at all.
    if n == 1 or m < n:
        return 1

    dp = [0] * (m + 1)
    dp[0] = 1
    for i in range(1, m + 1):
        total = dp[i - 1]
        if i >= n:
            total += dp[i - n]
        # Reduce at every step so stored values stay below MOD.
        dp[i] = total % MOD
    return dp[m]


def main():
    """Parse ``n`` and ``m`` from one line of stdin and print the result."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()