"""Sum of N ** popcount(x) over all 0 <= x <= M, modulo 998244353.

Input (stdin): one line with two integers ``N M``.
Output (stdout): the sum described above, reduced modulo ``mod``.
"""
import sys

mod = 998244353


def c(a):
    """Return the number of set bits of ``a`` (0 for any ``a <= 0``)."""
    return bin(a).count("1") if a > 0 else 0


def solve(n, m):
    """Return ``sum(n ** popcount(x) for x in range(m + 1)) % mod``.

    Args:
        n: base that is raised to the popcount of each x.
        m: inclusive upper bound of the range of x; must be >= 0.

    Returns:
        The sum reduced modulo ``mod``.

    Raises:
        ValueError: if ``m`` is negative.
    """
    if m < 0:
        raise ValueError("m must be non-negative")

    # Size every table from the actual bit length of m instead of a fixed
    # constant, so arbitrarily large m is handled without truncation.
    bits = m.bit_length()

    # binom[i][j] == C(i, j) % mod for 0 <= j <= i < max(bits, 1).
    binom = [[1]]
    for i in range(1, bits):
        prev = binom[-1]
        inner = [(prev[j - 1] + prev[j]) % mod for j in range(1, i)]
        binom.append([1] + inner + [1])

    # count[k] == number of x in [0, m] with exactly k set bits (mod ``mod``).
    count = [0] * (bits + 1)
    count[c(m)] += 1  # x == m itself

    # For every set bit i of m, take the x that agree with m above bit i,
    # have bit i cleared, and are free in the i lower bits. Such an x has
    # popcount(m >> i) - 1 fixed ones plus j free ones, in C(i, j) ways.
    high = m
    for i in range(bits):
        if high & 1:
            base = c(high) - 1
            for j, ways in enumerate(binom[i]):
                count[base + j] = (count[base + j] + ways) % mod
        high >>= 1

    # Evaluate sum(count[k] * n ** k) with a running power of n.
    total = 0
    power = 1
    for occurrences in count:
        total = (total + power * occurrences) % mod
        power = power * n % mod
    return total


def main():
    """Read ``N M`` from stdin and print the answer."""
    n, m = map(int, sys.stdin.readline().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()