n,k = map(int,input().split()) M = 64 d0 = [0]*M d1 = [0]*M d0[0] = 1 for i in range(M): if n>>i&1: nd0 = d0[:] nd1 = d1[:] if k>>i&1: nd0.insert(0,0) nd0.pop() for i in range(M-1): nd1[i+1] += d1[i] nd1[i] += d0[i] else: for i in range(M-1): nd0[i+1] += d0[i] + d1[i] d0,d1 = nd0,nd1 else: if k>>i&1: d0.pop() d0.insert(0,0) else: for i in range(M-1): d0[i+1] += d1[i] d1 = [0]*M print(sum(2**i*v for i,v in enumerate(d0))%998244353)