MOD = 998244353

# Per-bit multiplier, indexed as _BIT_FACTORS[k_bit][n_bit]:
#   K bit 0, N bit 0 -> 1    K bit 0, N bit 1 -> 3
#   K bit 1, N bit 0 -> 2    K bit 1, N bit 1 -> 4
_BIT_FACTORS = ((1, 3), (2, 4))


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return the product of per-bit factors of (k, n), reduced modulo *mod*.

    For every bit position d, the running result is multiplied by
    ``_BIT_FACTORS[(k >> d) & 1][(n >> d) & 1]`` and reduced modulo *mod*.

    Only positions up to the highest set bit of *n* or *k* are visited.
    Positions above that have both bits 0 and contribute a factor of 1, so
    this matches a fixed-width scan for any width that covers both inputs,
    with no arbitrary upper bound on the input size.

    Args:
        n: Non-negative integer N.
        k: Non-negative integer K.
        mod: Modulus applied after every multiplication (default ``MOD``).

    Returns:
        The product modulo *mod*. For n == k == 0 this is 1 % mod.

    Raises:
        ValueError: If *n* or *k* is negative. Python ints have infinitely
            many sign-extended 1 bits, so the per-bit product is undefined.
    """
    if n < 0 or k < 0:
        raise ValueError("n and k must be non-negative")
    result = 1 % mod
    for d in range(max(n.bit_length(), k.bit_length())):
        factor = _BIT_FACTORS[(k >> d) & 1][(n >> d) & 1]
        if factor != 1:  # a (0, 0) bit pair leaves the result unchanged
            result = result * factor % mod
    return result


def main() -> None:
    """Read ``N K`` from one line of standard input and print ``solve(N, K)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()