"""Print the multinomial coefficient for an even split of m into n parts.

Reads two integers ``n m`` from stdin and prints
m! / (a_0! * ... * a_{n-1}!) modulo 998244353, where the part sizes are
a_i = m // n + 1 for the first m % n parts and m // n for the rest.
"""

MOD = 998244353


def even_split_multinomial(n, m, mod=MOD):
    """Return m! / (a_0! * ... * a_{n-1}!) modulo ``mod``.

    The part sizes are a_i = m // n + (1 if i < m % n else 0), i.e. m is
    split as evenly as possible into n parts.

    Args:
        n: number of parts; must be positive.
        m: total being split; expected to be non-negative.
        mod: modulus. Division is done via Fermat's little theorem
            (x ** (mod - 2) as the inverse of x), so ``mod`` is assumed prime.

    Returns:
        The multinomial coefficient reduced modulo ``mod``.

    Raises:
        ZeroDivisionError: if ``n == 0``.
    """
    q, r = divmod(m, n)

    fac = [1] * (m + 1)
    for i in range(1, m + 1):
        fac[i] = fac[i - 1] * i % mod

    # Only two part sizes occur: r parts of size q + 1 and n - r parts of
    # size q. Raising each factorial to its multiplicity costs O(log n)
    # instead of one modular exponentiation per part. When r == 0, q + 1 can
    # exceed m (e.g. n == 1), so fac[q + 1] must not be read in that case.
    larger = fac[q + 1] if r else 1
    denom = pow(fac[q], n - r, mod) * pow(larger, r, mod) % mod

    # A single Fermat inverse of the whole denominator equals the product of
    # the per-part inverses.
    return fac[m] * pow(denom, mod - 2, mod) % mod


def main():
    """Read ``n m`` from stdin and print the even-split multinomial."""
    n, m = map(int, input().split())
    print(even_split_multinomial(n, m))


if __name__ == "__main__":
    main()