"""Sum N ** popcount(x) over all x in [0, M], modulo 998244353.

Input (stdin): one line containing two integers ``N M``.
Output (stdout): the sum, reduced modulo 998244353.
"""
import sys
from functools import lru_cache

mod = 998244353

# calc() never returns fewer slots than this, so results keep the historical
# fixed shape for every bound whose popcounts fit in 30 slots.
_MIN_WIDTH = 30


@lru_cache(maxsize=None)
def _popcount_histogram(x, width):
    """Return a ``width``-tuple whose i-th entry counts v in [0, x] with popcount i.

    ``width`` must exceed the largest popcount occurring in [0, x]; ``calc``
    guarantees that.  Tuples are returned so cached values cannot be mutated
    by callers.
    """
    if x == 0:
        # Only the value 0 is in range, and it has popcount 0.
        return (1,) + (0,) * (width - 1)
    # Even values 2k <= x have popcount(k), with k ranging over [0, x // 2].
    evens = _popcount_histogram(x // 2, width)
    # Odd values 2k + 1 <= x have popcount(k) + 1, with k in [0, (x - 1) // 2].
    odds = _popcount_histogram((x - 1) // 2, width)
    # Shift the odd histogram up by one slot.  Its last slot is always zero
    # (otherwise some value <= x would have popcount >= width), so dropping
    # it loses nothing.
    shifted_odds = (0,) + odds[:-1]
    return tuple(e + o for e, o in zip(evens, shifted_odds))


def calc(x):
    """Count the integers in [0, x] by number of set bits.

    Args:
        x: Inclusive, non-negative upper bound.

    Returns:
        A new list ``K`` where ``K[i]`` is how many v in [0, x] have exactly
        i one-bits.  The list has at least 30 entries and grows as needed so
        that no popcount is ever truncated.

    Raises:
        ValueError: If ``x`` is negative.
    """
    if x < 0:
        raise ValueError("x must be non-negative")
    # The largest popcount in [0, x] is (x + 1).bit_length() - 1, so this many
    # slots always suffice.
    width = max(_MIN_WIDTH, (x + 1).bit_length())
    return list(_popcount_histogram(x, width))


def solve(n, m):
    """Return the sum of ``n ** popcount(x)`` for x in [0, m], modulo ``mod``."""
    counts = calc(m)
    return sum(pow(n, bits, mod) * count for bits, count in enumerate(counts)) % mod


def main():
    """Read ``N M`` from stdin and print the answer."""
    n, m = map(int, sys.stdin.readline().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()