def weighted_popcount_sum(n, m, mod=998244353):
    """Return sum(n ** popcount(x) for x in range(m + 1)) modulo `mod`.

    Digit DP over the binary representation of m, from the most significant
    bit down.

    For every set bit of m at position `bit`, count the values x < m that:
      * agree with m on all higher bits, and
      * have a 0 at this position.

    Those x contribute n**ones_above, where ones_above is the number of set
    bits of m above `bit`. Each of the `bit` lower positions is then free, and
    contributes a factor (1 + n): a 0 costs weight 1, a 1 costs weight n.
    Finally, x == m itself contributes n**popcount(m).

    Args:
        n: Per-set-bit weight base.
        m: Inclusive upper bound of the range; must be non-negative.
        mod: Modulus applied to the result.

    Returns:
        The weighted count modulo `mod`.

    Raises:
        ValueError: If m is negative.
    """
    if m < 0:
        raise ValueError("m must be non-negative")
    total = 0
    ones_above = 0  # set bits of m strictly above the current position
    for bit in range(m.bit_length() - 1, -1, -1):
        if m >> bit & 1:
            # x matches m above `bit`, has 0 here, and its lower bits are free.
            total = (total + pow(n, ones_above, mod) * pow(n + 1, bit, mod)) % mod
            ones_above += 1
    # Account for x == m.
    return (total + pow(n, ones_above, mod)) % mod


def main():
    """Read "N M" from stdin and print the weighted popcount sum."""
    n, m = map(int, input().split())
    print(weighted_popcount_sum(n, m))


if __name__ == "__main__":
    main()