MOD = 998244353


def _bit_product(n: int, k: int, bits: int = 61) -> int:
    """Return the product of per-bit factors of ``n`` and ``k`` modulo MOD.

    For every bit position ``s`` in ``range(bits)`` the pair
    ``(bit s of n, bit s of k)`` contributes one factor:

        (0, 0) -> 1    (0, 1) -> 2    (1, 0) -> 3    (1, 1) -> 4

    Args:
        n: First integer; only its lowest ``bits`` bits are inspected.
        k: Second integer; only its lowest ``bits`` bits are inspected.
        bits: Number of low bit positions to scan. The default of 61 is
            the bound used by the original script, whose comment states
            the inputs satisfy N, K <= 10**18 < 2**60.

    Returns:
        The product of all per-bit factors, reduced modulo ``MOD``.
    """
    product = 1
    for shift in range(bits):
        n_bit = (n >> shift) & 1
        k_bit = (k >> shift) & 1
        # Closed form of the four-way table above: the k bit adds 1 and
        # the n bit adds 2 on top of the base factor 1.
        product = product * (1 + k_bit + 2 * n_bit) % MOD
    return product


def solve() -> None:
    """Read ``N K`` from one line of standard input and print the answer.

    The printed value is the per-bit factor product of N and K over the
    low 61 bits, taken modulo 998244353.
    """
    N, K = map(int, input().split())
    print(_bit_product(N, K))


# Run only when executed as a script, so importing this module does not
# block waiting on standard input.
if __name__ == "__main__":
    solve()