from functools import reduce
from operator import xor

MOD = 998244353


def count_matrices(n, m, row_xors, col_xors, bits=20, mod=MOD):
    """Count n x m matrices with prescribed row XORs and column XORs, mod `mod`.

    Cells are treated as non-negative integers with `bits` independent bits.
    The default of 20 comes from the original script; presumably it is the
    value range [0, 2**20) from the problem statement -- confirm.

    Counting argument, applied to each bit independently:
    - The top-left (n-1) x (m-1) cells can be chosen freely.
    - The last column (except its bottom cell) is then forced by the row XORs.
    - The last row (except its bottom cell) is forced by the column XORs.
    - The bottom-right cell is forced twice. The two forced values agree
      exactly when XOR(row_xors) == XOR(col_xors).

    Args:
        n: number of rows (assumed >= 1).
        m: number of columns (assumed >= 1).
        row_xors: required XOR of each row (length n).
        col_xors: required XOR of each column (length m).
        bits: number of independent bits per cell.
        mod: modulus for the result.

    Returns:
        0 if the total XORs disagree. Otherwise
        2 ** (bits * (n - 1) * (m - 1)) reduced modulo `mod`.
    """
    # XOR of all cells computed row-wise must equal the same XOR column-wise.
    if reduce(xor, row_xors, 0) != reduce(xor, col_xors, 0):
        return 0
    # Three-argument pow does fast modular exponentiation, so a huge
    # exponent is cheap.
    return pow(2, bits * (n - 1) * (m - 1), mod)


def main():
    """Read `n m`, then row XORs, then column XORs from stdin; print the count."""
    n, m = map(int, input().split())
    row_xors = list(map(int, input().split()))
    col_xors = list(map(int, input().split()))
    print(count_matrices(n, m, row_xors, col_xors))


if __name__ == "__main__":
    main()