import sys
def input():
    """Return the next stdin line with trailing whitespace (incl. newline) removed."""
    return sys.stdin.readline().rstrip()


def ii():
    """Read one line and parse it as a single integer."""
    return int(input())


def mi():
    """Read one line of whitespace-separated integers as a lazy ``map`` object."""
    return map(int, input().split())


def li():
    """Read one line of whitespace-separated integers as a list."""
    return list(mi())


INF = (1 << 63) - 1  # largest signed 64-bit integer
mod = 998244353


def solve(n, x, y, a, b, modulus=998244353):
    """Sum the final ``v`` over all pick sequences, modulo ``modulus``.

    Start from ``v = 0`` and repeat ``n`` times: pick any one of the first
    ``x`` values of ``a`` and any one of the first ``y`` values of ``b``, then
    set ``v = (v | a_pick) & b_pick``.  The result is the sum of the final
    ``v`` over all ``(x * y) ** n`` such sequences.

    ``|`` and ``&`` act on each bit independently, so the answer is built bit
    by bit.  For each bit, a two-state DP tracks how many sequences leave that
    bit clear (``off``) or set (``on``).

    Args:
        n: number of steps (``n == 0`` gives 0, since ``v`` stays 0).
        x: number of values of ``a`` to choose from.
        y: number of values of ``b`` to choose from.
        a: sequence holding at least ``x`` integers (assumed non-negative).
        b: sequence holding at least ``y`` integers (assumed non-negative).
        modulus: modulus for the result; the default equals the script's ``mod``.

    Returns:
        The sum described above, reduced modulo ``modulus``.
    """
    a = a[:x]
    b = b[:y]
    # Derive the bit width from the data instead of assuming values < 2**20.
    # Bits that are clear everywhere contribute 0, so scanning exactly up to
    # the widest value is both sufficient and safe.
    width = max(max(a, default=0), max(b, default=0)).bit_length()

    answer = 0
    for bit in range(width):
        a_one = sum((v >> bit) & 1 for v in a)
        b_one = sum((v >> bit) & 1 for v in b)
        a_zero = x - a_one
        b_zero = y - b_one

        # From "off": the bit turns on only if a has it and b keeps it.
        # Otherwise it stays off: (a has it, b clears it) or (a lacks it, any b).
        stay_off = (a_one * b_zero + a_zero * y) % modulus

        off, on = 1, 0  # v starts at 0, so the bit starts cleared
        for _ in range(n):
            # From "on", OR with any a keeps it on (x choices); b then decides.
            off, on = (
                (on * x * b_zero + off * stay_off) % modulus,
                (on * x + off * a_one) * b_one % modulus,
            )
        answer = (answer + on * pow(2, bit, modulus)) % modulus
    return answer


if __name__ == "__main__":
    n, x, y = mi()
    a = li()
    b = li()
    print(solve(n, x, y, a, b, mod))