MOD = 998244353


def solve(n, m, mod=MOD):
    """Return sum(n ** popcount(x) for x in range(m + 1)) modulo ``mod``.

    Equivalently, this counts n-tuples of non-negative integers whose bit
    sets are pairwise disjoint and whose bitwise OR is at most ``m`` (each
    set bit of the OR is owned by exactly one of the n entries).

    The sum is evaluated with a bitwise digit-DP over ``m`` in
    O(bit_length(m)) modular multiplications instead of iterating x.

    Args:
        n: Base of the power; any integer (0 ** 0 is treated as 1).
        m: Upper bound of x, inclusive; must be non-negative.
        mod: Modulus for the result (default 998244353).

    Returns:
        The sum reduced modulo ``mod``.

    Raises:
        ValueError: If ``m`` is negative (bit arithmetic on negative ints
            would otherwise produce a meaningless answer).
    """
    if m < 0:
        raise ValueError("m must be non-negative")

    total = 0
    # n ** (number of set bits of m strictly above the current bit), mod `mod`:
    # the weight of the prefix of x that so far matches m exactly.
    prefix = 1
    for i in range(m.bit_length() - 1, -1, -1):
        if (m >> i) & 1:
            # Put a 0 in bit i where m has a 1: x is now strictly below m, so
            # the lower i bits are free and sum to (1 + n) ** i.
            total = (total + pow(n + 1, i, mod) * prefix) % mod
            # Otherwise x keeps matching m and gains one more set bit.
            prefix = prefix * n % mod
    # The remaining term is x == m itself.
    return (total + prefix) % mod


def main():
    """Read ``N M`` from one line of stdin and print solve(N, M)."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()