def _side_total(values, n, other_twos):
    """Return the score contributed by one sequence.

    Each 1 in *values* is worth ``max(1, other_twos)``, each 2 is worth *n*,
    and any other value is worth nothing. Because every element in a category
    contributes the same amount, counting occurrences gives the same total as
    summing per element.

    Args:
        values: Sequence of ints for this side.
        n: The integer read as the first input line (used as the weight of a 2).
        other_twos: Number of 2s in the *other* sequence.

    Returns:
        The integer total for this side.
    """
    # max(1, ...) ensures a 1 is worth at least 1 even when the other side has no 2s.
    return values.count(1) * max(1, other_twos) + values.count(2) * n


def solve(n, s, t):
    """Return the larger of the two side totals for sequences *s* and *t*.

    NOTE(review): the logic assumes the entries are 1s and 2s (other values are
    simply ignored); confirm against the original problem statement.

    Args:
        n: The integer from the first input line.
        s: First sequence of ints.
        t: Second sequence of ints.

    Returns:
        max(total for s, total for t), where each side's 1s are weighted by the
        number of 2s on the opposite side (minimum 1) and its 2s are weighted by n.
    """
    s_twos = s.count(2)
    t_twos = t.count(2)
    return max(_side_total(s, n, t_twos), _side_total(t, n, s_twos))


def main():
    """Read N, S and T from stdin (one per line) and print the answer."""
    n = int(input())
    s = list(map(int, input().split()))
    t = list(map(int, input().split()))
    print(solve(n, s, t))


if __name__ == "__main__":
    main()