MOD = 998244353


def popcount_power_sum(n, m, mod=MOD):
    """Return sum(n ** popcount(x) for x in range(m + 1)) modulo ``mod``.

    Runs a digit DP over the binary digits of ``m``, most significant first.
    State ``pre[j] = [loose, tight]`` counts the prefixes built so far that
    contain ``j`` one-bits:
      * ``tight`` -- the prefix equals the same-length prefix of ``m``;
      * ``loose`` -- the prefix is already strictly smaller than ``m``'s.
    After the last digit, ``loose + tight`` in row ``j`` is the number of
    integers in ``[0, m]`` with exactly ``j`` set bits.

    Args:
        n: base raised to each popcount (any int; reduced by ``pow``).
        m: non-negative upper bound (inclusive).
        mod: modulus for the result; defaults to 998244353.

    Returns:
        The weighted count as an int in ``[0, mod)``.
    """
    bits = [int(c) for c in bin(m)[2:]]
    k = len(bits)
    # Row k + 1 is never reached (popcount <= k) but keeps the j + 1
    # writes below in range.
    pre = [[0, 0] for _ in range(k + 2)]
    pre[0][1] = 1  # the empty prefix is tight with zero one-bits
    for bit in bits:
        dp = [[0, 0] for _ in range(k + 2)]
        for j in range(k + 1):
            loose, tight = pre[j]
            # A prefix that is already smaller may take either digit.
            dp[j][0] += loose
            dp[j + 1][0] += loose
            if bit:
                # Copying the 1 stays tight; writing 0 drops below m.
                dp[j + 1][1] += tight
                dp[j][0] += tight
            else:
                # m has a 0 here, so a tight prefix must copy it.
                dp[j][1] += tight
        pre = [[a % mod, b % mod] for a, b in dp]

    total = 0
    for j, (loose, tight) in enumerate(pre):
        total = (total + pow(n, j, mod) * (loose + tight)) % mod
    return total


def main():
    """Read ``N M`` from stdin and print the answer."""
    n, m = map(int, input().split())
    print(popcount_power_sum(n, m))


if __name__ == "__main__":
    main()