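# Computes sum_{x=0}^{m} n^popcount(x) modulo 998244353.
# a[j] counts the integers x in [0, m] whose binary form has exactly j ones.
n, m = map(int, input().split())
mod = 998244353
a = [0] * (m.bit_length() + 1)  # popcount(x) <= bit_length(m) for every x <= m


def count(free_bits, fixed_ones):
    """Add to a[] the popcounts of all numbers whose high part is fixed
    (contributing fixed_ones ones) and whose low free_bits bits are arbitrary."""
    x = 1
    c = [1]  # c[i] = C(free_bits, i), built iteratively
    for i in range(free_bits):
        # exact division: C(f, i) * (f - i) == C(f, i + 1) * (i + 1)
        x = x * (free_bits - i) // (i + 1)
        c.append(x)
    for i, v in enumerate(c):
        a[i + fixed_ones] += v


# Walk the bits of m from most to least significant. At each 1 bit, every number
# that matches m above that bit and has a 0 there is smaller than m, and its
# lower bits can be chosen freely.
bits = bin(m)[2:]
s = 0  # number of 1 bits of m seen so far
for i, v in enumerate(bits):
    if v == "1":
        count(len(bits) - 1 - i, s)
        s += 1
a[s] += 1  # m itself

ans = 0
for i, v in enumerate(a):
    ans = (ans + pow(n, i, mod) * v) % mod
print(ans)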