# Module-level constants; neither is referenced by the code in this file.
mod = 1000000007
eps = 10**-9


def _furthest_equal_end(start, n, sum_a, xor_a):
    """Binary-search the furthest end index whose A-range has sum == xor.

    The search keeps ``ok`` (initially ``start``) as an index assumed to
    satisfy ``sum(A[start..ok]) == xor(A[start..ok])`` and ``ng`` (initially
    ``n``) as an index assumed not to, and narrows the gap to 1.

    NOTE(review): the result is the true maximum only if the predicate is
    monotone in the end index, which holds for non-negative integers --
    confirm against the input constraints.

    Args:
        start: Left end of the range (0-based).
        n: Exclusive upper bound for the right end.
        sum_a: Prefix sums of A, ``sum_a[k] == sum(A[:k])``.
        xor_a: Prefix xors of A, ``xor_a[k] == xor(A[:k])``.

    Returns:
        The final ``ok`` index, with ``start <= ok < n`` whenever ``start < n``.
    """
    ok, ng = start, n
    while ng - ok > 1:
        mid = (ok + ng) // 2
        range_sum = sum_a[mid + 1] - sum_a[start]
        range_xor = xor_a[mid + 1] ^ xor_a[start]
        if range_sum == range_xor:
            ok = mid
        else:
            ng = mid
    return ok


def _count_ranges(n, a, b):
    """Return the count that ``main`` prints for the parsed input.

    For every start index ``i`` in ``range(n)``:

    * ``ok`` is obtained from ``_furthest_equal_end`` on the tables of ``a``;
    * 1 is added when ``b[i] == 0``;
    * when ``ok != i``, let ``ones`` be the sum of the prefix xors of ``b``
      at positions ``i + 2 .. ok + 1``.  If the prefix xor at ``i + 1``
      equals 1, ``ones`` is added; otherwise ``(ok - i) - ones`` is added.

    NOTE(review): if every element of ``b`` is 0 or 1 (presumed, not shown
    here), the last step counts the ends ``r`` with ``i < r <= ok`` whose
    prefix xor of ``b`` matches the one at ``i + 1``.

    Args:
        n: Number of start positions to process.
        a: Integers used for the sum/xor condition.
        b: Integers used for the counting condition.

    Returns:
        The accumulated count as an ``int``.
    """
    sum_a = [0] * (n + 1)
    xor_a = [0] * (n + 1)
    for i, value in enumerate(a):
        sum_a[i + 1] = sum_a[i] + value
        xor_a[i + 1] = xor_a[i] ^ value

    xor_b = [0] * (n + 1)
    for i, value in enumerate(b):
        xor_b[i + 1] = xor_b[i] ^ value

    # ones_b[k] is the running sum of xor_b[1..k].
    ones_b = [0] * (n + 1)
    for i in range(n):
        ones_b[i + 1] = ones_b[i] + xor_b[i + 1]

    total = 0
    for i in range(n):
        ok = _furthest_equal_end(i, n, sum_a, xor_a)
        # The single-element range [i, i] is handled on its own.
        if b[i] == 0:
            total += 1
        if ok != i:
            ones = ones_b[ok + 1] - ones_b[i + 1]
            if xor_b[i + 1] == 1:
                total += ones
            else:
                total += (ok - i) - ones
    return total


def main():
    """Read N, then the A line and the B line from stdin; print the count."""
    import sys

    # Bound to a local name instead of shadowing the builtin ``input``.
    readline = sys.stdin.readline

    N = int(readline())
    A = list(map(int, readline().split()))
    B = list(map(int, readline().split()))

    print(_count_ranges(N, A, B))


if __name__ == '__main__':
    main()