def count_valid_subarrays(a, b):
    """Count subarrays accepted by a two-pointer (sliding window) scan.

    The window a[l:r] / b[l:r] is extended by element r only when both hold:

    * XOR-ing b[r] into the window's running b-XOR leaves it equal to 0;
    * adding a[r] to the window's running sum gives the same value as
      XOR-ing a[r] into the window's running a-XOR.

    For every left end l the number of accepted right ends (r - l) is added
    to the total.

    Because the window's b-XOR is always 0 while the scan runs, the first
    test reduces to ``b[r] == 0``. Because the window's sum always equals
    its XOR, the second test (for non-negative ints) means a[r] shares no
    set bit with the window. So for non-negative inputs this counts
    subarrays whose b-values are all zero and whose a-values are pairwise
    bit-disjoint.

    NOTE(review): if the intended condition was "XOR of b over the subarray
    is 0", that property is not monotone and a two-pointer scan cannot
    count it; confirm against the problem statement.

    Args:
        a: list of ints.
        b: list of ints, same length as ``a``.

    Returns:
        The number of accepted (l, r) pairs.

    Raises:
        ValueError: if ``a`` and ``b`` have different lengths.
    """
    if len(a) != len(b):
        raise ValueError("a and b must have the same length")
    n = len(a)
    total = 0
    r = 0
    # Running aggregates over the current window a[l:r] / b[l:r].
    bx = 0  # XOR of b over the window
    sx = 0  # sum of a over the window
    ax = 0  # XOR of a over the window
    for l in range(n):
        while r < n and (bx ^ b[r]) == 0 and sx + a[r] == (ax ^ a[r]):
            bx ^= b[r]
            sx += a[r]
            ax ^= a[r]
            r += 1
        total += r - l
        if r > l:
            # Element l is inside the window: drop it before advancing l.
            bx ^= b[l]
            sx -= a[l]
            ax ^= a[l]
        else:
            # Window is empty and element l was rejected, so it was never
            # added; don't subtract it, just restart the window after it.
            r = l + 1
    return total


def main():
    """Read n, then the a and b rows, from stdin and print the count."""
    n = int(input())
    a = list(map(int, input().split()))
    b = list(map(int, input().split()))
    # Only the first n values of each row are considered, as in the original.
    print(count_valid_subarrays(a[:n], b[:n]))


if __name__ == "__main__":
    main()