"""Read integers N and K from stdin and print
3**popcount(N & ~K) * 4**popcount(N & K) modulo 998244353.
"""

import sys

MOD = 998244353


def solve(n, k):
    """Return 3**popcount(n & ~k) * 4**popcount(n & k) modulo MOD.

    Each set bit of ``n`` contributes a factor depending on the matching bit
    of ``k``: 3 when that bit of ``k`` is 0, 4 when it is 1. Bits where ``n``
    is 0 contribute a factor of 1.

    NOTE(review): assumes n >= 0. For negative n, ``n & ~k`` can be
    negative and ``bin()`` would count the bits of its magnitude rather than
    a two's-complement popcount -- confirm the input constraints.

    Args:
        n: Non-negative integer whose set bits are examined.
        k: Integer mask selecting which set bits of ``n`` get factor 4.

    Returns:
        The product reduced modulo ``MOD`` (an int in ``[0, MOD)``).
    """
    # Bits set in n but clear in k -> factor 3 each.
    d = bin(n & ~k).count("1")
    # Bits set in both n and k -> factor 4 each.
    m = bin(n & k).count("1")
    # Three-argument pow keeps intermediates small even for large exponents.
    return pow(3, d, MOD) * pow(4, m, MOD) % MOD


def main():
    """Read N and K (whitespace-separated) from stdin and print solve(N, K)."""
    # read().split() accepts N and K on one line or on separate lines.
    tokens = sys.stdin.read().split()
    n, k = int(tokens[0]), int(tokens[1])
    print(solve(n, k))


if __name__ == "__main__":
    main()