MOD = 998244353  # prime modulus; Fermat inverses below rely on primality


def solve(n, m):
    """Return sum_{k=0}^{m//n} C(m - (n-1)*k, k) modulo MOD.

    Term k counts the arrangements of k pieces of length n together with
    m - n*k pieces of length 1. The total number of pieces is
    m - n*k + k, and C(total, k) chooses which of them are the long ones.
    The sum is therefore the number of ordered ways to fill length m with
    parts of size 1 and size n.

    Args:
        n: size of the long part; n == 0 raises ZeroDivisionError
            (as the original script did).
        m: total length to fill.

    Returns:
        The count modulo MOD. Assumes m < MOD; otherwise factorials vanish
        mod MOD and the modular inverse is meaningless.
    """
    # The largest binomial argument is m (the k == 0 term), so factorials up
    # to m are enough.
    fact = [1] * (m + 1)
    for i in range(1, m + 1):
        fact[i] = fact[i - 1] * i % MOD

    def comb(a, b):
        # a! / ((a-b)! * b!) with the denominator inverted via Fermat's
        # little theorem (MOD is prime).
        return fact[a] * pow(fact[a - b] * fact[b] % MOD, MOD - 2, MOD) % MOD

    return sum(comb(m - (n - 1) * k, k) for k in range(m // n + 1)) % MOD


def main():
    """Read "n m" from stdin and print solve(n, m)."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()