"""Count permutations sharing the cycle type of an input permutation, mod 1e9+7.

Reads ``n`` followed by a 1-based permutation ``p_1 .. p_n`` from standard
input and prints a single integer.
"""

import sys
from collections import Counter

MOD = 10**9 + 7


def _cycle_length_counts(n, perm):
    """Return a Counter mapping cycle length -> number of cycles of that length.

    ``perm`` is a 1-based permutation stored in a 0-based list, i.e. element
    ``i`` (1 <= i <= n) maps to ``perm[i - 1]``.
    """
    seen = [False] * (n + 1)
    counts = Counter()
    for start in range(1, n + 1):
        if seen[start]:
            continue
        length = 0
        node = start
        # Follow the cycle through ``start`` until it closes on a seen element.
        while not seen[node]:
            seen[node] = True
            node = perm[node - 1]
            length += 1
        counts[length] += 1
    return counts


def _class_size(n, cycle_freq):
    """Return n! / prod(length**count * count!) modulo MOD.

    This is the number of permutations of ``n`` elements with the cycle type
    described by ``cycle_freq``.
    """
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % MOD

    denominator = 1
    for length, count in cycle_freq.items():
        denominator = denominator * pow(length, count, MOD) % MOD
        denominator = denominator * fact[count] % MOD

    # MOD is prime, so Fermat's little theorem gives the modular inverse.
    return fact[n] * pow(denominator, MOD - 2, MOD) % MOD


def _solve(n, perm):
    """Return the answer for permutation ``perm`` of size ``n``.

    The result is the class size from ``_class_size``; it is halved when every
    cycle length is odd and no length repeats.
    NOTE(review): that is the condition under which a conjugacy class of S_n
    splits in A_n -- presumably the intent; confirm against the problem
    statement.
    """
    cycle_freq = _cycle_length_counts(n, perm)
    class_size = _class_size(n, cycle_freq)

    keeps_full_size = any(
        length % 2 == 0 or count >= 2 for length, count in cycle_freq.items()
    )
    if keeps_full_size:
        return class_size

    # All cycle lengths are odd and distinct, so the true (un-reduced) class
    # size is n! / prod(length).  The divisor is odd, hence the true value is
    # even exactly when n >= 2.  The parity must be decided this way: the
    # residue modulo the odd prime MOD says nothing about the true parity.
    if n < 2:
        return class_size  # equals 1 for n in {0, 1}; nothing to halve
    return class_size * pow(2, MOD - 2, MOD) % MOD


def main():
    """Read ``n`` and the permutation from stdin and print the answer."""
    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    perm = [int(tok) for tok in tokens[1:n + 1]]
    print(_solve(n, perm))


if __name__ == '__main__':
    main()