MOD = 998244353


def solve(n, m):
    """Return count / 4**n modulo 998244353 (a modular fraction).

    ``count`` is the number of ordered pairs (a, b) of length-n sequences
    over {+1, -1} whose sums u = sum(a), v = sum(b) satisfy u * v == m.
    There are 4**n such pairs in total, so the result is the probability
    that the product of the two sums equals ``m``.

    A sum s is reached by C(n, (n + s) // 2) sequences when |s| <= n and
    s has the same parity as n, and by none otherwise.

    NOTE(review): the original task statement is not visible here; this
    docstring describes the quantity the arithmetic computes.

    Args:
        n: number of ±1 steps per sequence (expected to be >= 0).
        m: target product of the two sums (any sign, any magnitude).

    Returns:
        The probability as an integer in [0, MOD).
    """
    # Factorial / inverse-factorial tables so each binomial is O(1).
    fact = [1] * (n + 1)
    for k in range(1, n + 1):
        fact[k] = fact[k - 1] * k % MOD
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], MOD - 2, MOD)  # Fermat inverse; MOD is prime
    for k in range(n, 0, -1):
        inv_fact[k - 1] = inv_fact[k] * k % MOD

    def ways(s):
        """Number (mod MOD) of length-n ±1 sequences whose sum is s."""
        if abs(s) > n or (n + s) % 2:
            return 0
        r = (n + s) // 2
        return fact[n] * inv_fact[r] % MOD * inv_fact[n - r] % MOD

    # ways(-s) == ways(s), so only |m| matters; the sign of m only fixes
    # which sign patterns of (u, v) are allowed (always 2 of them).
    m = abs(m)
    if m == 0:
        # Pairs with u == 0 or v == 0, by inclusion-exclusion:
        # ways(0) * 2**n + 2**n * ways(0) - ways(0)**2.
        # When n is odd, ways(0) == 0 and the whole count is 0.
        z = ways(0)
        count = 2 * z * pow(2, n, MOD) - z * z
    else:
        count = 0
        # Enumerate factorisations m = i * j with i <= j. Any i > n gives
        # ways(i) == 0, so the loop is bounded by n as well as sqrt(m).
        # The exact integer test i * i > m avoids float-sqrt error and
        # overflow for huge m.
        for i in range(1, n + 1):
            if i * i > m:
                break
            if m % i:
                continue
            j = m // i
            # 2 sign patterns for (|u|, |v|) = (i, j); doubled when i != j
            # because (i, j) and (j, i) are distinct ordered pairs.
            count += (2 if i == j else 4) * ways(i) * ways(j)
    # Divide by 4**n via its modular inverse.
    return count % MOD * pow(pow(4, n, MOD), MOD - 2, MOD) % MOD


def main():
    """Read "n m" from stdin and print solve(n, m)."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()