def solve(N=None, M=None, mod=998244353):
    """Count ordered sequences of parts of size 1 and size ``N`` that sum to ``M``.

    The count satisfies ``dp[i] = dp[i - 1] + dp[i - N]`` (the last part is
    either a 1 or an ``N``) with ``dp[0] = 1``, and is reduced modulo ``mod``
    at every step.

    When ``N`` and ``M`` are both omitted they are read as two integers from
    one line of standard input, exactly as the original script did.

    Args:
        N: Size of the large part; must be >= 1.
        M: Target total; must be >= 0.
        mod: Modulus applied to the count (default 998244353).

    Returns:
        The number of sequences modulo ``mod``. For ``N == 1`` the two part
        sizes coincide and 1 is returned (the recurrence would otherwise
        yield ``2**M``).

    Raises:
        ValueError: if only one of ``N``/``M`` is supplied, or if ``N < 1``
            or ``M < 0``.
    """
    if N is None and M is None:
        N, M = map(int, input().split())
    elif N is None or M is None:
        raise ValueError("N and M must be given together")
    # Negative M previously crashed with IndexError; N < 1 makes the count ill-defined.
    if N < 1 or M < 0:
        raise ValueError(f"need N >= 1 and M >= 0, got N={N}, M={M}")
    # Both part sizes are equal, so there is a single way (original special case).
    if N == 1:
        return 1
    dp = [0] * (M + 1)
    dp[0] = 1  # the empty sequence
    for i in range(1, M + 1):
        dp[i] = dp[i - 1]  # last part has size 1
        if i >= N:
            dp[i] += dp[i - N]  # last part has size N
        dp[i] %= mod
    return dp[M]


if __name__ == "__main__":
    # Guarded so importing this module does not block on stdin.
    print(solve())