n,k = map(int,input().split()) M = 64 d0 = [0]*M d1 = [0]*M d0[0] = 1 for i in range(M): if n>>i&1: if k>>i&1: for i in range(M)[::-1]: d1[i] += d1[i-1] + d0[i] d0.insert(0,0) d0.pop() else: for i in range(M-1)[::-1]: d0[i] += d0[i-1] + d1[i-1] else: if k>>i&1: d0.pop() d0.insert(0,0) else: for i in range(M-1): d0[i+1] += d1[i] d1 = [0]*M print(sum(2**i*v for i,v in enumerate(d0))%998244353)