# Count ordered sequences of steps of size 1 and size n whose sizes sum to m,
# modulo 998244353. Reads "n m" from one line of stdin and prints the count.

MOD = 998244353


def count_ways(n, m, mod=MOD):
    """Return the number of ordered step sequences (steps 1 and ``n``) summing to ``m``.

    ``ways[k]`` holds the number of sequences whose total is ``k``. It is
    filled forward: each ``ways[i]`` is pushed into ``ways[i + 1]`` (a step of
    1) and, if it fits, into ``ways[i + n]`` (a step of ``n``).

    Args:
        n: size of the long step; must be >= 1.
        m: target total; must be >= 0.
        mod: modulus applied to every stored count.

    Returns:
        The count modulo ``mod``. For ``n == 1`` the original program returns
        1 rather than running the DP; that special case is preserved here.
        NOTE(review): confirm this matches the problem statement.

    Raises:
        ValueError: if ``n < 1`` or ``m < 0``.
    """
    if n < 1:
        # n <= 0 would make ways[i + n] alias an already-processed (or, for
        # negative n, wrapped-around) entry and silently corrupt the table.
        raise ValueError(f"n must be >= 1, got {n}")
    if m < 0:
        raise ValueError(f"m must be >= 0, got {m}")

    if n == 1:
        # Both step sizes coincide; the original program defines the answer as 1.
        return 1 % mod

    ways = [0] * (m + 1)
    ways[0] = 1
    for i in range(m):
        cur = ways[i]  # final: only smaller indices ever write into ways[i]
        ways[i + 1] = (ways[i + 1] + cur) % mod
        if i + n <= m:
            ways[i + n] = (ways[i + n] + cur) % mod
    return ways[m] % mod


def main():
    """Read ``n m`` from stdin and print ``count_ways(n, m)``."""
    n, m = map(int, input().split())
    print(count_ways(n, m))


if __name__ == "__main__":
    main()