MOD = 998244353


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return the product of per-bit factors of ``n`` and ``k``, modulo ``mod``.

    Each bit position ``b`` contributes an independent factor determined by
    the pair (bit ``b`` of ``k``, bit ``b`` of ``n``):

    * k-bit 0: ``1 + 2 * n_bit``   -> 1 if the n-bit is 0, 3 if it is 1
    * k-bit 1: ``2 * (1 + n_bit)`` -> 2 if the n-bit is 0, 4 if it is 1

    Positions above the highest set bit of both operands have k-bit 0 and
    n-bit 0, so they contribute a factor of 1 and need not be visited.
    Iterating only up to the larger bit length therefore matches a fixed
    61-bit scan for inputs below ``2**61``. It also handles larger values
    instead of silently ignoring their high bits.

    Args:
        n: Non-negative integer.
        k: Non-negative integer.
        mod: Modulus applied after every multiplication (defaults to ``MOD``).

    Returns:
        The product of all per-bit factors, reduced modulo ``mod``.

    Raises:
        ValueError: If ``n`` or ``k`` is negative. A negative Python int has
            infinitely many set bits in two's complement, so the product
            would not be finite.
    """
    if n < 0 or k < 0:
        raise ValueError("n and k must be non-negative")

    # `1 % mod` keeps the result reduced even when the loop does not run
    # (n == k == 0) and mod == 1.
    result = 1 % mod
    for b in range(max(n.bit_length(), k.bit_length())):
        n_bit = (n >> b) & 1
        k_bit = (k >> b) & 1
        if k_bit == 0:
            contrib = 1 + 2 * n_bit
        else:
            contrib = 2 * (1 + n_bit)
        result = result * contrib % mod
    return result


def main() -> None:
    """Read ``N K`` from one line of stdin and print ``solve(N, K)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()