MOD = 998244353


def solve(n, m, mod=MOD):
    """Return the number of ordered sequences of parts of size 1 and size n
    whose sizes sum to m, taken modulo ``mod``.

    If a sequence uses i parts of size 1, then m - i must be divisible by n.
    It then uses j = (m - i) // n parts of size n, and those i + j parts can
    be ordered in C(i + j, i) ways. The answer is the sum of these binomials
    over every valid i.

    When n == 1 the two part kinds coincide, so exactly one sequence exists.

    Assumes n >= 1 and m >= 0 (n == 0 has no meaning here). Also assumes
    m < mod, so that factorials up to m are invertible modulo the prime mod.
    """
    if n == 1:
        return 1

    # For n >= 2, i + (m - i) // n <= m, so factorials up to m are enough.
    fact = [1] * (m + 1)
    for k in range(1, m + 1):
        fact[k] = fact[k - 1] * k % mod

    # Inverse factorials: one Fermat inversion, then walk downwards using
    # 1/(k-1)! = k * 1/k!.
    inv_fact = [1] * (m + 1)
    inv_fact[m] = pow(fact[m], mod - 2, mod)
    for k in range(m, 0, -1):
        inv_fact[k - 1] = inv_fact[k] * k % mod

    total = 0
    # Only counts i with i ≡ m (mod n) leave a remainder divisible by n.
    for ones in range(m % n, m + 1, n):
        parts = ones + (m - ones) // n
        total += fact[parts] * inv_fact[ones] % mod * inv_fact[parts - ones]
    return total % mod


def main():
    """Read ``n m`` from stdin and print the count modulo 998244353."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()