MOD = 998244353


def count_ways(n, m, mod=MOD):
    """Return dp[m] modulo ``mod``, where dp counts sequences of steps of size 1 or n.

    dp[0] = 1, and for i >= 1:
        dp[i] = dp[i - 1] + (dp[i - n] if n <= i else 0)   (mod ``mod``)

    For n >= 2 this is the number of ordered ways to write ``m`` as a sum of
    1s and ns. For n == 1 both transitions are taken, so dp[i] = 2 * dp[i - 1].

    Args:
        n: Size of the "long" step.
        m: Target total (m >= 0).
        mod: Modulus applied after each step; defaults to 998244353.

    Returns:
        dp[m] reduced modulo ``mod``.
    """
    dp = [0] * (m + 1)
    dp[0] = 1  # one way to reach total 0: take no steps
    for i in range(1, m + 1):
        # The accumulation order matches the original script exactly.
        # For n == 0 the second term reads dp[i] after the first addition.
        dp[i] += dp[i - 1]
        if n <= i:
            dp[i] += dp[i - n]
        dp[i] %= mod
    return dp[m]


def main():
    """Read "N M" from stdin and print the count modulo 998244353."""
    n, m = map(int, input().split())
    print(count_ways(n, m))


if __name__ == "__main__":
    main()