"""Sum n**popcount(x) over every integer 0 <= x <= m, modulo 998244353.

Reads two integers ``n`` and ``m`` from a single line of standard input and
prints the result.
"""

MOD = 998244353


def weighted_popcount_sum(n, m, mod=MOD):
    """Return ``sum(n ** popcount(x) for x in range(m + 1)) % mod``.

    Uses a digit DP over the binary digits of ``m`` from most to least
    significant. It runs in O(log m) time with no recursion, so arbitrarily
    large ``m`` does not hit Python's recursion limit. It keeps no global
    cache, so calls with different ``n`` cannot see each other's results.

    Args:
        n: Factor contributed by each set bit of ``x``; any int (negative
            values are reduced modulo ``mod``).
        m: Inclusive upper bound of the range. A negative ``m`` describes
            an empty range, so the result is 0.
        mod: Positive modulus applied to the result.

    Returns:
        The weighted count, reduced into ``range(mod)``.
    """
    if m < 0:
        return 0  # empty range -> empty sum
    n %= mod
    # below: total weight of the bit-prefixes already strictly less than the
    #        matching prefix of m; each later bit is free to be 0 (factor 1)
    #        or 1 (factor n), hence the (1 + n) multiplier per bit.
    # equal: weight of the single prefix still identical to m's prefix.
    below, equal = 0, 1 % mod
    for bit in bin(m)[2:]:
        below = below * (1 + n) % mod
        if bit == "1":
            # Choosing 0 where m has a 1 moves the tight prefix strictly below m.
            below = (below + equal) % mod
            # Choosing 1 keeps it tight and adds one set bit.
            equal = equal * n % mod
    # x == m itself is the final tight prefix.
    return (below + equal) % mod


def main():
    """Read ``n m`` from stdin and print the weighted count modulo MOD."""
    n, m = map(int, input().split())
    print(weighted_popcount_sum(n, m))


if __name__ == "__main__":
    main()