"""Sum of N**popcount(j) over 0 <= j <= K, modulo 998244353.

Reads ``N K`` from stdin and prints the result.
"""

MOD = 998244353


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return ``sum(n ** popcount(j) for j in range(k + 1)) % mod``.

    Walks the bits of ``k`` from most to least significant. When bit ``i`` of
    ``k`` is set, every ``j`` that matches ``k`` above bit ``i``, has a 0 at
    bit ``i``, and has arbitrary lower ``i`` bits is strictly less than ``k``.
    Those ``j`` contribute ``n**(set bits of k above i) * (n + 1)**i`` in
    total, by the binomial theorem over the free low bits. After the loop,
    ``j == k`` itself is added.

    Args:
        n: Base raised to each popcount (any int; reduced modulo ``mod``).
        k: Inclusive upper bound of ``j``; must be non-negative. Unlike a
            fixed 31-bit scan, every bit of ``k`` is considered.
        mod: Modulus for the result (defaults to 998244353).

    Returns:
        The sum modulo ``mod``, in ``range(mod)``.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    total = 0
    # n ** (number of set bits of k strictly above the current position), mod `mod`.
    prefix_weight = 1
    for i in range(k.bit_length() - 1, -1, -1):
        if (k >> i) & 1:
            # All j < k that first drop below k at bit i.
            total = (total + pow(n + 1, i, mod) * prefix_weight) % mod
            prefix_weight = prefix_weight * n % mod
    # prefix_weight is now n ** popcount(k): the j == k term.
    return (total + prefix_weight) % mod


def main() -> None:
    """Read ``N K`` from stdin and print ``solve(N, K)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()