"""Print, modulo 998244353, the value of a linear recurrence read from stdin.

Input is one line ``n m``. The answer is ``dp[m]`` where ``dp[i] = 1`` for
``0 <= i < n`` and ``dp[i] = dp[i-1] + dp[i-n]`` otherwise. That is the number
of ordered ways to write ``m`` as a sum of parts equal to 1 or ``n``.
"""

MOD = 998244353


def solve(n: int, m: int, mod: int = MOD) -> int:
    """Return ``dp[m] % mod`` for the recurrence ``dp[i] = dp[i-1] + dp[i-n]``.

    Args:
        n: Size of the "long" part. Assumed ``n >= 1``; this is not validated.
        m: Target index of the recurrence. Assumed ``m >= 0``.
        mod: Modulus applied to every computed value. Assumed ``mod > 1``.

    Returns:
        ``dp[m]`` reduced modulo ``mod``. For ``n == 1`` this returns 1. In
        that case both part sizes coincide, and the program counts a single
        arrangement instead of the ``2**m`` the recurrence would give.
        NOTE(review): presumably the problem statement requires this; confirm.
    """
    if n == 1:
        return 1
    # Indices below n can only be built from parts of size 1, so each has
    # exactly one arrangement. The same seed value covers dp[0].
    dp = [1] * (m + 1)
    for i in range(n, m + 1):
        dp[i] = (dp[i - 1] + dp[i - n]) % mod
    return dp[m]


def main() -> None:
    """Read ``n m`` from standard input and print ``solve(n, m)``."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()