n,k = map(int,input().split())
M = 64
d0 = [0]*M
d1 = [0]*M
d0[0] = 1
for i in range(M):
    if n>>i&1:
        if k>>i&1:
            for i in range(M)[::-1]:
                d1[i] += d1[i-1] + d0[i]
            d0.insert(0,0)
            d0.pop()
        else:
            for i in range(M-1)[::-1]:
                d0[i] += d0[i-1] + d1[i-1]
    else:
        if k>>i&1:
            d0.pop()
            d0.insert(0,0)
        else:
            for i in range(M-1):
                d0[i+1] += d1[i]
            d1 = [0]*M

print(sum(2**i*v for i,v in enumerate(d0))%998244353)