MOD = 998244353


def popcount_power_sum(n: int, m: int, mod: int = MOD) -> int:
    """Return (sum of n ** popcount(x) for x in range(m + 1)) % mod.

    Walks the binary digits of ``m`` from the most significant bit. Every
    x < m agrees with m on some prefix, then has a 0 where m has a 1, and
    its remaining lower bits are free. Each free bit contributes a factor of
    1 (bit clear) or n (bit set), i.e. (n + 1) per free bit, while the shared
    prefix contributes n ** (ones in that prefix). x == m is added last.

    Args:
        n: Base raised to the popcount of each x.
        m: Inclusive upper bound of the summation range; must be >= 0.
        mod: Modulus applied to the result (defaults to 998244353).

    Returns:
        The sum reduced modulo ``mod``.

    Raises:
        ValueError: If ``m`` is negative.
    """
    if m < 0:
        raise ValueError(f"m must be non-negative, got {m}")

    bits = bin(m)[2:]
    width = len(bits)
    total = 0
    ones_so_far = 0  # popcount of the prefix of m processed so far

    for i, bit in enumerate(bits):
        if bit == "1":
            # Count every x that matches m above position i, has a 0 here,
            # and has arbitrary bits in the `free` lower positions.
            free = width - 1 - i
            total += pow(n, ones_so_far, mod) * pow(n + 1, free, mod)
            total %= mod
            ones_so_far += 1

    # Finally account for x == m itself.
    total += pow(n, ones_so_far, mod)
    return total % mod


def main() -> None:
    """Read "N M" from stdin and print the popcount power sum for them."""
    n, m = map(int, input().split())
    print(popcount_power_sum(n, m))


if __name__ == "__main__":
    main()