"""Sum, over every choice sequence, of the final value of s <- (s | a_i) & b_i.

Input (stdin):
    n x y
    a_1 .. a_x
    b_1 .. b_y
At each of n steps one value a is picked from the first list and one value b
from the second; starting from s = 0 the state is updated to (s | a) & b.
The program prints the sum of the final s over all (x*y)**n sequences,
modulo 998244353.
"""
import sys

MOD = 998244353


def solve(n, a, b, mod=MOD):
    """Return the sum of final states over all pick sequences, modulo ``mod``.

    Each bit evolves independently under s <- (s | a) & b, so the answer is a
    per-bit two-state DP weighted by 2**bit.

    Args:
        n: number of steps (n == 0 yields 0 because the state starts at 0).
        a: values that may be OR-ed in at each step (assumed non-negative).
        b: values that may be AND-ed in at each step (assumed non-negative).
        mod: modulus applied to the result.

    Returns:
        The sum modulo ``mod``.
    """
    x, y = len(a), len(b)
    # Only bits that actually occur in the input can ever be set; deriving the
    # width from the data (instead of a fixed 20) keeps large values correct.
    width = max(a + b, default=0).bit_length()
    ans = 0
    for bit in range(width):
        a_one = sum((v >> bit) & 1 for v in a)
        b_one = sum((v >> bit) & 1 for v in b)
        a_zero = x - a_one
        b_zero = y - b_one
        # Transition weights for one step of s <- (s | a) & b on this bit.
        stay_on = x * b_one              # s=1: result is just b's bit
        on_to_off = x * b_zero
        off_to_on = a_one * b_one        # s=0: result is a & b
        stay_off = a_one * b_zero + a_zero * y
        on, off = 0, 1                   # number of sequences ending in each state
        for _ in range(n):
            on, off = (
                (on * stay_on + off * off_to_on) % mod,
                (on * on_to_off + off * stay_off) % mod,
            )
        ans = (ans + on * pow(2, bit, mod)) % mod
    return ans


def main():
    """Read the problem from stdin and print the answer."""
    readline = sys.stdin.readline
    n, x, y = map(int, readline().split())
    a = list(map(int, readline().split()))
    b = list(map(int, readline().split()))
    # Only the first x / y values are used, as the declared counts specify.
    print(solve(n, a[:x], b[:y]))


if __name__ == "__main__":
    main()