"""Sum, over every sequence of picks, of the value left after n rounds.

Input (stdin):
    line 1: n x y
    line 2: x integers (list a)
    line 3: y integers (list b)

Output (stdout): a single integer, the total modulo 998244353.
"""
from typing import Sequence

MOD = 998244353


def solve(n: int, a: Sequence[int], b: Sequence[int], mod: int = MOD) -> int:
    """Return the summed final value over all (len(a) * len(b)) ** n sequences.

    Each bit is handled independently with a two-state recurrence.  A bit
    starts at 0, and one round picks one element of ``a`` and one of ``b``
    and maps the bit ``v`` to ``(v | a_bit) & b_bit``:

    * from 0 it becomes 1 only when both picked elements have the bit set;
    * from 1 it stays 1 exactly when the picked ``b`` element has the bit set
      (the picked ``a`` element is then irrelevant).

    The recurrence counts sequences directly instead of tracking
    probabilities, so no modular inverse and no final rescaling by
    ``(x * y) ** n`` are needed.

    Args:
        n: number of rounds.
        a: candidates for the OR operand (assumed non-negative; not checked).
        b: candidates for the AND operand (assumed non-negative; not checked).
        mod: modulus applied to the result.

    Returns:
        Sum over bits of ``2**bit * (number of sequences ending with that
        bit set)``, reduced modulo ``mod``.
    """
    x, y = len(a), len(b)
    # Cover every bit that occurs in the input rather than a fixed width,
    # so values of any magnitude are handled.
    bits = max((v.bit_length() for v in (*a, *b)), default=0)

    total = 0
    for bit in range(bits):
        a1 = sum(v >> bit & 1 for v in a)
        b1 = sum(v >> bit & 1 for v in b)
        a0, b0 = x - a1, y - b1

        # Number of (a, b) pairs realising each transition in one round.
        zero_to_zero = (a0 * y + a1 * b0) % mod
        zero_to_one = a1 * b1 % mod
        one_to_zero = b0 * x % mod
        one_to_one = b1 * x % mod

        # off / on = number of pick sequences leaving the bit at 0 / 1.
        off, on = 1, 0
        for _ in range(n):
            off, on = (
                (off * zero_to_zero + on * one_to_zero) % mod,
                (off * zero_to_one + on * one_to_one) % mod,
            )

        total = (total + pow(2, bit, mod) * on) % mod
    return total


def main() -> None:
    """Read the instance from stdin and print the answer."""
    n, x, y = map(int, input().split())
    a = list(map(int, input().split()))
    b = list(map(int, input().split()))
    # Only the first x / y values take part, matching the declared sizes.
    print(solve(n, a[:x], b[:y]))


if __name__ == "__main__":
    main()