"""Per-bit modular sum over two integer lists A and B.

For every bit position k, let ``pA_k = cntA_k / X`` and ``pB_k = cntB_k / Y``,
where cntA_k / cntB_k count the elements of A / B with bit k set.
The script prints

    (X*Y)^N * sum_k 2^k * pA_k * pB_k * sum_{i=0}^{N-1} D_k^i   (mod MOD),

with ``D_k = pB_k * (1 - pA_k)``. All arithmetic is modulo MOD.
NOTE(review): the original problem statement is not visible; the formula above
is what the code computes, not a restatement of the task.
"""

MOD = 998244353

# Minimum number of bit positions examined. 18 was the original fixed width;
# keeping it as a floor means inputs below 2**18 are processed exactly as before.
MIN_BITS = 18


def _bit_counts(values, bits):
    """Return a list whose k-th entry is how many of *values* have bit k set."""
    return [sum((v >> k) & 1 for v in values) for k in range(bits)]


def solve(N, X, Y, A, B):
    """Return the per-bit modular sum described in the module docstring.

    Args:
        N: Number of terms in each per-bit geometric series. It is also the
            exponent applied to X*Y.
        X: Denominator for the fraction of A with a given bit set
            (``len(A)`` when called from ``main``).
        Y: Denominator for the fraction of B with a given bit set
            (``len(B)`` when called from ``main``).
        A: Iterable of integers.
        B: Iterable of integers.

    Returns:
        The result as an int in ``[0, MOD)``.
    """
    A = list(A)
    B = list(B)

    # Examine every bit that actually occurs instead of a fixed 18, so large
    # values are not silently truncated.
    top = max(A + B, default=0)
    bits = max(MIN_BITS, top.bit_length())

    cnt_a = _bit_counts(A, bits)
    cnt_b = _bit_counts(B, bits)

    # Modular inverses via Fermat's little theorem (MOD is prime).
    inv_x = pow(X, MOD - 2, MOD)
    inv_y = pow(Y, MOD - 2, MOD)
    scale = pow(X % MOD * (Y % MOD) % MOD, N, MOD)

    total = 0
    for k in range(bits):
        p_a = cnt_a[k] * inv_x % MOD
        p_b = cnt_b[k] * inv_y % MOD
        ratio = p_b * (1 - p_a) % MOD  # D_k
        first = p_a * p_b % MOD
        denom = (1 - ratio) % MOD
        if denom:
            # Closed form of sum_{i<N} first * ratio^i:
            # first * (1 - ratio^N) / (1 - ratio).
            term = (
                first * pow(denom, MOD - 2, MOD) % MOD
                * (1 - pow(ratio, N, MOD)) % MOD
            )
        else:
            # ratio == 1 (mod MOD), so every one of the N terms equals `first`.
            term = N % MOD * first % MOD
        total = (total + term * pow(2, k, MOD)) % MOD
    return total * scale % MOD


def main():
    """Read ``N X Y``, then X integers for A and Y integers for B; print the result."""
    import sys

    tokens = sys.stdin.read().split()
    n, x, y = map(int, tokens[:3])
    a = [int(t) for t in tokens[3:3 + x]]
    b = [int(t) for t in tokens[3 + x:3 + x + y]]
    print(solve(n, x, y, a, b))


if __name__ == "__main__":
    main()