"""Read ``n m`` from stdin and print sum_k C(m - (n-1)k, k) mod 998244353."""

MOD = 998244353


def count_tilings(n, m, mod=MOD):
    """Return sum_{k=0}^{m//n} C(m - (n-1)*k, k) modulo *mod*, or 1 when n < 2.

    For n >= 2 each term counts the ordered sequences made of k parts of
    size n and m - n*k parts of size 1. There are m - (n-1)*k parts in
    total, and C(m - (n-1)*k, k) chooses which of them have size n. The sum
    is therefore the number of ordered ways to write m as a sum of 1s and ns.

    Args:
        n: size of the long part.
        m: target total.
        mod: modulus. It must be prime and greater than m, because inverses
            are taken with Fermat's little theorem.

    Returns:
        The sum modulo *mod*. Returns 1 for any n < 2, matching the original
        script's special case. Returns 0 for negative m, because the sum is
        empty.
    """
    # Decide this branch before any arithmetic. The original indexed a list,
    # `[sum(...), 1][n < 2]`, which evaluated the sum anyway: that crashed on
    # n == 0 (m // 0) and wasted work on n == 1.
    if n < 2:
        # NOTE(review): the reason n < 2 maps to 1 is not visible here;
        # this value is kept from the original.
        return 1
    if m < 0:
        return 0

    fact = [1] * (m + 1)
    for i in range(1, m + 1):
        fact[i] = fact[i - 1] * i % mod

    # Compute one modular inverse, then fill the rest downward:
    # 1/(i-1)! = i * (1/i!).
    inv_fact = [1] * (m + 1)
    inv_fact[m] = pow(fact[m], mod - 2, mod)
    for i in range(m, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % mod

    total = 0
    for k in range(m // n + 1):
        a = m - (n - 1) * k  # total number of parts; a - k == m - n*k >= 0
        total += fact[a] * inv_fact[k] % mod * inv_fact[a - k]
    return total % mod


def main():
    """Read two integers ``n m`` from one line of stdin and print the answer."""
    n, m = map(int, input().split())
    print(count_tilings(n, m))


if __name__ == "__main__":
    main()