"""Read two integers ``n`` and ``m`` from stdin and print one sequence term.

The printed value is term ``m`` (1-based) of the sequence defined by
``f(k) = 1`` for ``k < n``, ``f(n) = 2`` and
``f(k) = f(k - 1) + f(k - n)`` for ``k > n``, reduced modulo 998244353.
"""

MOD = 998244353


def solve(n: int, m: int) -> int:
    """Return the value to print for the input pair ``(n, m)``.

    Args:
        n: Lag of the recurrence (the term ``n`` positions back is added).
        m: 1-based index of the requested term.

    Returns:
        ``1`` when ``n == 1`` or ``n > m``, ``2`` when ``n == m``, otherwise
        term ``m`` of the recurrence modulo ``MOD``.

    Both arguments are assumed to be non-negative integers; negative values
    are not meaningful for this recurrence.
    """
    # NOTE(review): n == 1 is special-cased to 1 exactly as in the original
    # script, even though the recurrence below would give a different value
    # for it. Kept as-is; confirm against the problem statement.
    if n == 1 or n > m:
        return 1
    if n == m:
        return 2

    # terms[i] holds f(i + 1): n - 1 leading ones followed by f(n) = 2.
    terms = [1] * (n - 1) + [2]

    # The list already has n entries and only index m - 1 is read, so
    # m - n further terms are enough (the original appended m terms).
    for _ in range(m - n):
        terms.append((terms[-1] + terms[-n]) % MOD)
    return terms[m - 1]


def main() -> None:
    """Parse ``n m`` from one line of stdin and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()