"""Read two integers ``n m`` from stdin and print

    sum_{k=0}^{m // n} binomial(m - n*k + k, k)  (mod 998244353).
"""

M = 998244353  # prime modulus; primality is what makes the Fermat inverse in C valid
F = [1]  # F[i] == i! % M; grown on demand by _extend_factorials


def _extend_factorials(limit):
    """Grow the shared table ``F`` in place until ``F[limit]`` exists."""
    for i in range(len(F), limit + 1):
        F.append(F[i - 1] * i % M)


def C(A, B):
    """Return binomial(A, B) modulo ``M``.

    Returns 0 whenever ``B`` lies outside ``[0, A]`` (which includes every
    negative ``A``). The factorial table is extended lazily, so any
    non-negative ``A`` is accepted regardless of what was computed before.
    The result is only meaningful for ``A < M``: beyond that the factorials
    are divisible by ``M`` and have no modular inverse.
    """
    if not 0 <= B <= A:
        return 0
    _extend_factorials(A)
    # Fermat's little theorem: x**(M-2) is the inverse of x modulo the prime M.
    denominator_inverse = pow(F[A - B] * F[B], M - 2, M)
    return F[A] * denominator_inverse % M


def solve(n, m):
    """Return ``sum(C(m - n*k + k, k) for k in 0..m // n)`` modulo ``M``.

    Raises:
        ZeroDivisionError: if ``n`` is 0 (from ``m // n``).
    """
    largest_k = m // n
    return sum(C(m - n * k + k, k) for k in range(largest_k + 1)) % M


def main():
    """Parse ``n m`` from one line of stdin and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()