"""Compute sum_{x=0}^{M} N**popcount(x) modulo 998244353.

Input (stdin): two integers ``N M`` on one line.
Output (stdout): the sum above, reduced modulo ``MOD``.
"""

MOD = 998244353


def popcount_histogram(limit):
    """Count integers in ``[0, limit)`` grouped by their number of set bits.

    Returns a list ``cnt`` where ``cnt[j]`` is the number of ``x`` with
    ``0 <= x < limit`` and ``popcount(x) == j``.  The list has length
    ``max(limit.bit_length(), 1) + 1``; trailing entries may be zero.
    A non-positive ``limit`` describes an empty range and yields all zeros.
    """
    # Size everything from the actual bit width instead of a fixed 30 bits,
    # so no high bit of ``limit`` is ever silently ignored.
    width = max(limit.bit_length(), 1)
    cnt = [0] * (width + 1)
    if limit <= 0:
        return cnt

    # Pascal's triangle with exact integers: binom[e][i] == C(e, i).
    # Exact values keep the sanity check below meaningful for any width.
    binom = [[1]]
    for e in range(1, width):
        prev = binom[-1]
        binom.append([1] + [prev[i - 1] + prev[i] for i in range(1, e)] + [1])

    prefix = 0  # set bits of ``limit`` strictly above the current bit ``e``
    for e in range(width - 1, -1, -1):
        if limit >> e & 1:
            # Every x that matches ``limit`` above bit e, has 0 at bit e and
            # any pattern in the low e bits is < limit; choosing i of those
            # low bits gives popcount prefix + i in C(e, i) ways.
            for i in range(e + 1):
                cnt[prefix + i] += binom[e][i]
            prefix += 1

    # Internal invariant: the buckets partition [0, limit).
    assert sum(cnt) == limit
    return cnt


def solve(n, m, mod=MOD):
    """Return ``sum(n ** popcount(x) for x in range(m + 1)) % mod``.

    Negative ``m`` denotes an empty range and returns 0.
    """
    if m < 0:
        return 0
    ans = 0
    npow = 1  # n**j mod ``mod`` for the current popcount j (0**0 == 1)
    for count in popcount_histogram(m + 1):
        ans = (ans + count * npow) % mod
        npow = npow * n % mod
    return ans


def main():
    """Read ``N M`` from stdin and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()