MOD = 10**9 + 7


def _cycle_length_counts(perm):
    """Return ``{cycle_length: number_of_cycles}`` for the mapping *perm*.

    *perm* describes a 1-based mapping stored in a 0-based list: element
    ``i`` (1 <= i <= len(perm)) is sent to ``perm[i - 1]``.
    """
    n = len(perm)
    visited = [False] * (n + 1)  # index 0 is unused; elements are 1-based
    counts = {}
    for start in range(1, n + 1):
        if visited[start]:
            continue
        # Follow the mapping from `start` until we reach an element that
        # has already been seen, counting the steps taken.
        length = 0
        node = start
        while not visited[node]:
            visited[node] = True
            node = perm[node - 1]
            length += 1
        counts[length] = counts.get(length, 0) + 1
    return counts


def _count_same_cycle_type(perm):
    """Return ``N! / prod(l**c_l * c_l!)`` modulo ``MOD``.

    ``N`` is ``len(perm)`` and ``c_l`` is the number of cycles of length
    ``l`` found in *perm*.  For a permutation this quotient is the number
    of permutations of N elements that share its cycle type.
    """
    n = len(perm)

    # Factorials modulo MOD up to n; every cycle count is at most n.
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % MOD

    denominator = 1
    for length, cnt in _cycle_length_counts(perm).items():
        term = pow(length, cnt, MOD) * fact[cnt] % MOD
        denominator = denominator * term % MOD

    # MOD is prime, so Fermat's little theorem gives the modular inverse.
    return fact[n] * pow(denominator, MOD - 2, MOD) % MOD


def main():
    """Read ``N`` and ``N`` integers from stdin and print the answer.

    Input is whitespace-separated: first ``N``, then the ``N`` entries of
    the 1-based permutation.  Any tokens after those are ignored.

    Raises:
        ValueError: if stdin is empty, a required token is not an
            integer, or fewer than ``N`` entries follow ``N``.
    """
    import sys

    tokens = sys.stdin.read().split()
    if not tokens:
        raise ValueError("expected N followed by N permutation entries")

    n = int(tokens[0])
    perm = [int(tok) for tok in tokens[1:n + 1]]
    if len(perm) != n:
        raise ValueError(
            f"expected {n} permutation entries, got {len(perm)}"
        )

    print(_count_same_cycle_type(perm))


if __name__ == '__main__':
    main()