"""Count colourings of a permutation's cycles, modulo 998244353.

Input (whitespace separated, read from stdin): ``N M`` followed by the
1-based permutation ``P_1 .. P_N``.

Every cycle of the permutation is treated as a cycle graph.  The script
counts proper colourings of all cycles that use each of ``M`` labelled
colours at least once (inclusion-exclusion over the number of colours
actually available) and divides that count by ``M!``.
NOTE(review): dividing by ``M!`` presumably turns labelled colourings into
unlabelled groupings -- confirm against the problem statement.
"""
import sys
from collections import Counter

MOD = 998244353


def _cycle_length_counts(perm):
    """Return a Counter mapping cycle length -> number of such cycles.

    *perm* is a 0-based permutation given as a list of indices.
    """
    seen = [False] * len(perm)
    counts = Counter()
    for start in range(len(perm)):
        if seen[start]:
            continue
        length = 0
        node = start
        while not seen[node]:
            seen[node] = True
            length += 1
            node = perm[node]
        counts[length] += 1
    return counts


def _factorial_tables(limit):
    """Return ``(fact, inv_fact)`` for ``0..limit``, both modulo MOD."""
    fact = [1] * (limit + 1)
    for i in range(1, limit + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (limit + 1)
    # One Fermat inversion for the top entry, then walk downwards.
    inv_fact[limit] = pow(fact[limit], MOD - 2, MOD)
    for i in range(limit, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD
    return fact, inv_fact


def _solve(m, perm):
    """Return the answer for *m* colours and the 0-based permutation *perm*.

    The value is
    ``sum_k (-1)^(m-k) * C(m, k) * prod_cycles chi_n(k)``, divided by ``m!``,
    modulo MOD, where ``chi_n(k) = (k-1)^n + (-1)^n * (k-1)`` is the number
    of proper k-colourings of an n-cycle.
    """
    length_counts = _cycle_length_counts(perm)

    # An even cycle needs at least 2 colours and an odd one at least 3;
    # with fewer colours than the largest such minimum the product is 0.
    min_colours = max(
        (2 if length % 2 == 0 else 3 for length in length_counts),
        default=0,
    )
    if m < min_colours:
        return 0

    fact, inv_fact = _factorial_tables(m)

    total = 0
    # Every k below min_colours contributes 0, so the sum starts there.
    for k in range(min_colours, m + 1):
        ways = 1
        # Cycles of equal length share the same term, so raise the term to
        # the multiplicity instead of recomputing it once per cycle.
        for length, count in length_counts.items():
            parity_sign = 1 if length % 2 == 0 else -1
            per_cycle = (pow(k - 1, length, MOD) + parity_sign * (k - 1)) % MOD
            ways = ways * pow(per_cycle, count, MOD) % MOD
            if ways == 0:
                break
        binom = fact[m] * inv_fact[k] % MOD * inv_fact[m - k] % MOD
        contribution = binom * ways % MOD
        if (m - k) % 2:
            total -= contribution
        else:
            total += contribution
        total %= MOD

    return total * inv_fact[m] % MOD


def main():
    """Read ``N M`` and the permutation from stdin and print the answer."""
    # Token-based parsing accepts the permutation on one line or wrapped
    # across several lines.
    tokens = sys.stdin.read().split()
    n, m = int(tokens[0]), int(tokens[1])
    perm = [int(tok) - 1 for tok in tokens[2:2 + n]]  # to 0-based indices
    print(_solve(m, perm))


if __name__ == "__main__":
    main()