import java.util.*;

/**
 * Reads two permutations {@code a} and {@code b} of {@code 1..n} from standard input.
 *
 * <p>For every start index {@code i}, two walkers begin at {@code (i, i)}. They step in lockstep
 * along the inverse permutations of {@code a} and {@code b} until they land on the same index
 * again. If that takes {@code pos} steps, {@code C(pos, 2) = pos * (pos - 1) / 2} is added to the
 * answer. The sum is printed modulo {@link #MOD}.
 */
class Main {
    /** Modulus applied to the printed total. */
    static final long MOD = 1000000007;

    /**
     * Number of levels of the binary-lifting tables built by an earlier version of {@link #main}.
     * Only level 0 (the plain inverse) was ever read, so the tables are no longer built; the
     * constant is kept for source compatibility.
     */
    static final int B = 17;

    /**
     * Returns the inverse permutation, i.e. {@code result[permutation[i]] == i} for every i.
     *
     * @param permutation a permutation of {@code 0..n-1}
     * @return the inverse of {@code permutation}
     */
    static int[] inverse(int[] permutation) {
        int n = permutation.length;
        int[] result = new int[n];
        for (int i = 0; i < n; ++i) {
            result[permutation[i]] = i;
        }
        return result;
    }

    /**
     * Returns the permutation composed with itself, i.e.
     * {@code result[i] == permutation[permutation[i]]}.
     *
     * @param permutation a permutation of {@code 0..n-1}
     * @return {@code permutation} applied twice
     */
    static int[] twice(int[] permutation) {
        int n = permutation.length;
        int[] result = new int[n];
        for (int i = 0; i < n; ++i) {
            result[i] = permutation[permutation[i]];
        }
        return result;
    }

    /**
     * Returns {@code k * (k - 1) / 2} modulo {@link #MOD} for {@code k >= 0}.
     *
     * <p>The naive expression overflows {@code long} once {@code k} exceeds roughly 3.04e9. This
     * version halves whichever factor is even before reducing, so the division stays exact. Both
     * reduced factors are then below {@code MOD}, and their product fits in a {@code long}.
     *
     * @param k number of items, non-negative
     * @return {@code C(k, 2) mod MOD}
     */
    static long choose2Mod(long k) {
        long x = k;
        long y = k - 1;
        if (x % 2 == 0) {
            x /= 2;
        } else {
            y /= 2;
        }
        return (x % MOD) * (y % MOD) % MOD;
    }

    /** Reads {@code n} 1-based values from {@code scan} and returns them shifted to 0-based. */
    private static int[] readPermutation(Scanner scan, int n) {
        int[] p = new int[n];
        for (int i = 0; i < n; ++i) {
            p[i] = Integer.parseInt(scan.next()) - 1;
        }
        return p;
    }

    /**
     * Entry point.
     *
     * <p>Input: {@code n}, then {@code n} values of {@code a}, then {@code n} values of {@code b},
     * all 1-based. Output: the total described in the class comment.
     */
    public static void main(String[] args) {
        Scanner scan = new Scanner(System.in);
        int n = Integer.parseInt(scan.next());
        int[] a = readPermutation(scan, n);
        int[] b = readPermutation(scan, n);
        int[] invA = inverse(a);
        int[] invB = inverse(b);

        long total = 0;
        for (int i = 0; i < n; ++i) {
            // Walk both inverse permutations from (i, i) until the walkers coincide again.
            // The pair map (x, y) -> (invA[x], invB[y]) is a bijection on pairs when the inputs
            // are valid permutations (assumed, not checked). The orbit of (i, i) therefore
            // returns to (i, i), which guarantees termination.
            long pos = 0;
            int ai = i;
            int bi = i;
            do {
                ai = invA[ai];
                bi = invB[bi];
                pos++;
            } while (ai != bi);
            total = (total + choose2Mod(pos)) % MOD;
        }
        System.out.println(total);
    }
}