"""Read n, m and sequences A (length n), B (length m); print a count mod 998244353.

The printed value is 0 when XOR(A) != XOR(B), and otherwise
2 ** (20 * (n - 1) * (m - 1)) mod 998244353.

NOTE(review): this formula matches the number of n x m matrices of 20-bit
integers whose row XORs are A and column XORs are B (the top-left
(n-1) x (m-1) cells are free; the last row/column are then forced, and they
agree only when XOR(A) == XOR(B)). Confirm against the problem statement.
"""
import sys
from functools import reduce
from operator import xor

MOD = 998244353
BITS = 20


def solve(n, m, A, B, bits=BITS, mod=MOD):
    """Return the answer for row targets A and column targets B.

    Args:
        n: number of rows; only the first ``n`` entries of ``A`` are used.
        m: number of columns; only the first ``m`` entries of ``B`` are used.
        A: sequence of row XOR targets.
        B: sequence of column XOR targets.
        bits: bit width of each cell value (defaults to 20, as in the original).
        mod: modulus of the result (defaults to 998244353).

    Returns:
        0 if XOR(A[:n]) != XOR(B[:m]); otherwise
        ``pow(2, bits * (n - 1) * (m - 1), mod)``.
    """
    # XOR-ing all row targets and XOR-ing all column targets both equal the
    # XOR of every cell, so the two totals must match for any answer to exist.
    if reduce(xor, A[:n], 0) != reduce(xor, B[:m], 0):
        return 0
    # Three-argument pow does modular exponentiation, so the (potentially
    # huge) exponent never materialises 2**k as a big integer.
    return pow(2, bits * (n - 1) * (m - 1), mod)


def main():
    """Parse whitespace-separated tokens from stdin and print the answer."""
    # Token-based parsing tolerates A/B being wrapped across several lines.
    data = sys.stdin.buffer.read().split()
    n, m = int(data[0]), int(data[1])
    A = [int(tok) for tok in data[2:2 + n]]
    B = [int(tok) for tok in data[2 + n:2 + n + m]]
    print(solve(n, m, A, B))


if __name__ == "__main__":
    main()