import java.util.Scanner;

/**
 * Reads an integer {@code n} followed by two sequences of {@code n} integers each from standard
 * input, and prints a score computed from how many entries of each sequence equal 1 and how many
 * equal 2. Any other values are read but otherwise ignored.
 */
public class Main {

    /** Index of the "count of 1s" slot in the array returned by {@link #countOnesAndTwos}. */
    private static final int ONES = 0;

    /** Index of the "count of 2s" slot in the array returned by {@link #countOnesAndTwos}. */
    private static final int TWOS = 1;

    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            int n = sc.nextInt();
            int[] x = countOnesAndTwos(sc, n);
            int[] y = countOnesAndTwos(sc, n);
            System.out.println(score(n, x[ONES], x[TWOS], y[ONES], y[TWOS]));
        }
    }

    /**
     * Reads the next {@code n} integers from {@code sc} and tallies them.
     *
     * @param sc source of tokens
     * @param n number of integers to consume (non-positive consumes nothing)
     * @return a two-element array: {@code [ONES]} = count of 1s, {@code [TWOS]} = count of 2s
     */
    private static int[] countOnesAndTwos(Scanner sc, int n) {
        int[] counts = new int[2];
        for (int i = 0; i < n; i++) {
            int value = sc.nextInt();
            if (value == 1) {
                counts[ONES]++;
            } else if (value == 2) {
                counts[TWOS]++;
            }
        }
        return counts;
    }

    /**
     * Combines the per-sequence tallies into the final answer.
     *
     * <p>Every 2 in the first sequence contributes {@code n}; every 2 in the second sequence
     * contributes {@code n - xTwo}. A bonus based on the counts of 1s is added depending on which
     * sequences contain no 2 at all.
     *
     * @return the score, as a {@code long} because {@code xTwo * n} alone exceeds
     *     {@code Integer.MAX_VALUE} once {@code n > 46340}
     */
    private static long score(int n, int xOne, int xTwo, int yOne, int yTwo) {
        // Widen before multiplying: int arithmetic would wrap silently on large n.
        long ans = (long) xTwo * n + (long) yTwo * (n - xTwo);
        if (xTwo == 0 && yTwo == 0) {
            ans += Math.max(xOne, yOne);
        } else if (xTwo == 0) {
            ans += yOne;
        } else if (yTwo == 0) {
            ans += xOne;
        }
        return ans;
    }
}