"""Sum of ``N ** popcount(x)`` over ``0 <= x <= M``, modulo 998244353.

Reads two integers ``N M`` from a single line of standard input and prints
``sum(N ** popcount(x) for x in range(M + 1)) % 998244353``.
"""

MOD = 998244353


def solve(n, m, mod=MOD):
    """Return ``sum(n ** popcount(x) for x in range(m + 1)) % mod``.

    Runs a digit DP over the binary digits of *m*, most significant first,
    in O(log m) time.  Write ``S(p)`` for the answer with upper bound ``p``
    (so ``S(-1) == 0`` and ``S(0) == 1``).  Every ``y`` splits as
    ``2*z + b`` with ``popcount(y) == popcount(z) + b``, which gives::

        S(2p)     = S(p)     + n * S(p - 1)
        S(2p - 1) = S(p - 1) + n * S(p - 1)
        S(2p + 1) = S(p)     + n * S(p)

    The pair ``(S(p - 1), S(p))`` can therefore be extended one bit of *m*
    at a time.

    Args:
        n: Base of the power.  Any integer works; only ``n % mod`` matters.
        m: Inclusive upper bound of the range.  A negative bound denotes an
            empty range and yields 0.
        mod: Modulus applied to the result (defaults to 998244353).

    Returns:
        The sum reduced modulo *mod*.
    """
    if m < 0:
        # Empty range; also avoids walking the digits of a negative number.
        return 0
    n %= mod
    # (S(p - 1), S(p)) for the empty prefix p = 0.
    below, upto = 0, 1 % mod
    for bit in format(m, "b"):
        if bit == "1":
            # Prefix p becomes 2p + 1.
            below, upto = (upto + n * below) % mod, (upto + n * upto) % mod
        else:
            # Prefix p becomes 2p.
            below, upto = (below + n * below) % mod, (upto + n * below) % mod
    return upto


def main():
    """Read ``N M`` from standard input and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()