import java.util.*;

/**
 * Reads two integer sequences of length {@code n} from standard input and prints, for the first
 * sequence, how many pairs {@code j < i} have {@code pos(first[j]) >= pos(first[i])}, where
 * {@code pos(v)} is the index at which {@code v} appears in the second sequence.
 *
 * <p>Input format (whitespace separated): {@code n}, then the {@code n} values of the first
 * sequence, then the {@code n} values of the second sequence. Values are used directly as indices
 * into an array of size {@code n + 1}, so they are assumed to lie in {@code [0, n]}; presumably the
 * two sequences are permutations of the same distinct values (TODO confirm against the problem
 * statement), in which case the result is the number of pairs whose relative order differs.
 */
public class Main {
    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            int n = sc.nextInt();
            int[] arr = new int[n];
            for (int i = 0; i < n; i++) {
                arr[i] = sc.nextInt();
            }
            // order[v] = index at which value v occurs in the second sequence
            // (0 for values that never occur, matching the array default).
            int[] order = new int[n + 1];
            for (int i = 0; i < n; i++) {
                order[sc.nextInt()] = i;
            }
            // Map every element of the first sequence to its position in the second one.
            int[] positions = new int[n];
            for (int i = 0; i < n; i++) {
                positions[i] = order[arr[i]];
            }
            System.out.println(countInversions(positions));
        }
    }

    /**
     * Counts pairs {@code j < i} with {@code positions[j] >= positions[i]} in O(n log n).
     *
     * @param positions values, each in {@code [0, positions.length)}
     * @return the pair count; {@code long} because it can reach {@code n(n-1)/2}
     */
    private static long countInversions(int[] positions) {
        int n = positions.length;
        // Fenwick tree over values 0..n-1; value v is stored at 1-based index v + 1.
        int[] tree = new int[n + 1];
        long count = 0;
        for (int i = 0; i < n; i++) {
            int x = positions[i];
            // i elements were recorded before this one; those not strictly below x count.
            count += i - countLessThan(tree, x);
            for (int k = x + 1; k <= n; k += k & -k) {
                tree[k]++;
            }
        }
        return count;
    }

    /** Returns how many recorded values in the Fenwick {@code tree} are strictly less than {@code x}. */
    private static int countLessThan(int[] tree, int x) {
        int sum = 0;
        for (int k = x; k > 0; k -= k & -k) {
            sum += tree[k];
        }
        return sum;
    }
}