#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

using u32 = unsigned int;
using u64 = unsigned long long int;

// Input state: n = permutation length, m = upper bound of the matrix size /
// inclusion-exclusion index used below, p = the permutation (0-based).
u32 n, m;
std::vector<u32> p;

/*
Let A be the m x m matrix with 0 on the diagonal and 1 elsewhere, e.g. for m = 4

    A = [0 1 1 1]
        [1 0 1 1]
        [1 1 0 1]
        [1 1 1 0]

With J the all-ones matrix (J^i = m^{i-1} J for i >= 1, J^0 = I):

    tr(A^k) = tr((J - I)^k)
            = \sum_{i=0}^k \binom{k}{i} (-1)^{k-i} tr(J^i)
            = (-1)^k * m + \sum_{i=1}^k \binom{k}{i} (-1)^{k-i} tr(m^{i-1} J)
            = (-1)^k * m + \sum_{i=1}^k \binom{k}{i} (-1)^{k-i} m^i
            = (m - 1)^k + (-1)^k * (m - 1)

(Equivalently: the number of closed walks of length k in the complete graph on
m vertices, i.e. the number of proper m-colourings of a k-cycle.)
*/

/// Integer modulo MOD stored as a value in [0, MOD).
/// The converting constructor does NOT reduce its argument; callers pass
/// values that are already below MOD. It is intentionally implicit so that
/// expressions such as `T r = 1` and `fact[i-1] * i` keep working.
template<u32 MOD>
struct Mint {
    u32 n;
    constexpr Mint(u32 n = 0): n(n) {}
    constexpr Mint operator-() const { return Mint(n ? MOD - n : 0); }
    constexpr Mint &operator+=(const Mint &rhs) {
        n += rhs.n;
        if (n >= MOD) n -= MOD;
        return *this;
    }
    constexpr Mint &operator-=(const Mint &rhs) {
        if (rhs.n > n) n += MOD;  // borrow MOD so the subtraction cannot wrap
        n -= rhs.n;
        return *this;
    }
    constexpr Mint &operator*=(const Mint &rhs) {
        n = static_cast<u32>(static_cast<u64>(n) * rhs.n % MOD);
        return *this;
    }
    friend constexpr Mint operator+(const Mint &lhs, const Mint &rhs) { return Mint(lhs) += rhs; }
    friend constexpr Mint operator-(const Mint &lhs, const Mint &rhs) { return Mint(lhs) -= rhs; }
    friend constexpr Mint operator*(const Mint &lhs, const Mint &rhs) { return Mint(lhs) *= rhs; }
    friend constexpr bool operator==(const Mint &lhs, const Mint &rhs) { return lhs.n == rhs.n; }
    friend constexpr bool operator!=(const Mint &lhs, const Mint &rhs) { return lhs.n != rhs.n; }
};

/// Binary exponentiation: returns a^n using O(log n) multiplications.
/// T must be constructible from 1 and support operator*=.
template<typename T>
T mypow(T a, u32 n) {
    T r = 1;
    for (; n; n >>= 1) {
        if (n & 1) r *= a;
        a *= a;
    }
    return r;
}

/// Multiplicative inverse computed as a^(MOD-2) (Fermat's little theorem);
/// this is only an inverse when MOD is prime and a is non-zero.
template<u32 MOD>
Mint<MOD> inv(Mint<MOD> a) {
    return mypow(a, MOD - 2);
}

constexpr u32 mod = 998244353;
using mint = Mint<mod>;

/// Returns the cycle lengths of the permutation p (0-based, p[i] is the image
/// of i), in order of each cycle's smallest index.
std::vector<u32> perm(const std::vector<u32> &p) {
    constexpr u32 visited = -1U;  // sentinel written over entries already consumed
    std::vector<u32> q = p;
    std::vector<u32> cycles;
    for (u32 i = 0; i < q.size(); i++) {
        if (q[i] == visited) continue;
        u32 c = 1;
        u32 x = q[i];
        q[i] = visited;
        // Follow the cycle until we come back to the already-marked start.
        while (q[x] != visited) {
            const u32 z = q[x];
            q[x] = visited;
            x = z;
            c++;
        }
        cycles.push_back(c);
    }
    return cycles;
}

/// Run-length summary of x: returns (value, multiplicity) pairs sorted by
/// value. An empty input yields an empty result.
template<typename T>
std::vector<std::pair<T, u32>> countv(const std::vector<T> &x) {
    std::vector<std::pair<T, u32>> r;
    if (x.empty()) return r;  // y[0] below would be out of bounds
    auto y = x;
    std::sort(std::begin(y), std::end(y));
    u32 cnt = 1;
    T prev = y[0];
    for (u32 i = 1; i < y.size(); i++) {
        if (y[i] == prev) {
            cnt++;
        } else {
            r.emplace_back(prev, cnt);
            cnt = 1;
            prev = y[i];
        }
    }
    r.emplace_back(prev, cnt);
    return r;
}

// Linear sieve: result[i] is the smallest prime factor of i for 2 <= i < limit
// (entries 0 and 1 stay 0).
static std::vector<u32> min_prime_factors(u32 limit) {
    std::vector<u32> spf(limit);
    std::vector<u32> primes;
    for (u32 i = 2; i < limit; i++) {
        if (spf[i] == 0) {
            primes.push_back(i);
            spf[i] = i;
        }
        for (const u32 q : primes) {
            // 64-bit product so the bound check itself cannot overflow.
            const u64 multiple = static_cast<u64>(q) * i;
            if (multiple >= limit || q > spf[i]) break;
            spf[multiple] = q;
        }
    }
    return spf;
}

// Computes  sum_{i=2}^{bound} (-1)^{bound-i} * v[i] / (i! * (bound-i)!)  mod `mod`,
// where v[i] = prod over cycles of length c of ((i-1)^c + (-1)^c * (i-1)).
// This equals the original  (sum_i +-C(bound,i) * v[i]) / bound!  because
// C(bound,i) / bound! = 1 / (i! * (bound-i)!).
// NOTE(review): this looks like inclusion-exclusion over the number of colours
// actually used, divided by bound! to forget colour labels — confirm against
// the problem statement.
static mint solve(u32 bound, const std::vector<u32> &permutation) {
    if (bound < 2) return mint(0);  // the sum over i in [2, bound] is empty

    const auto cycle_classes = countv(perm(permutation));
    const std::vector<u32> spf = min_prime_factors(bound);

    std::vector<mint> v(bound + 1, mint(1));
    std::vector<mint> powc(bound);  // powc[i] = i^c for the current cycle length c
    for (const auto &[c, k] : cycle_classes) {
        // i -> i^c is completely multiplicative: exponentiate at primes only
        // and combine through the smallest prime factor elsewhere.
        powc[1] = 1;
        for (u32 i = 2; i < bound; i++) {
            powc[i] = (spf[i] == i) ? mypow(mint(i), c)
                                    : powc[spf[i]] * powc[i / spf[i]];
        }
        for (u32 i = 2; i <= bound; i++) {
            const mint base(i - 1);
            // (i-1)^c + (-1)^c * (i-1), raised to the number k of such cycles.
            const mint per_cycle = (c & 1) ? powc[i - 1] - base : powc[i - 1] + base;
            v[i] *= mypow(per_cycle, k);
        }
    }

    std::vector<mint> fact(bound + 1);
    fact[0] = 1;
    for (u32 i = 1; i <= bound; i++) fact[i] = fact[i - 1] * mint(i);
    std::vector<mint> ifact(bound + 1);
    ifact[bound] = inv(fact[bound]);
    for (u32 i = bound; i-- > 0;) ifact[i] = ifact[i + 1] * mint(i + 1);

    mint ans = 0;
    for (u32 i = 2; i <= bound; i++) {
        const mint term = ifact[i] * ifact[bound - i] * v[i];
        if ((bound - i) & 1) ans -= term;
        else ans += term;
    }
    return ans;
}

/// Reads `n m` followed by a 1-based permutation of length n from stdin and
/// prints the result of solve() on one line. Exits with status 1 on malformed
/// input (missing numbers or an entry outside [1, n]).
int main() {
    if (std::scanf("%u%u", &n, &m) != 2) return 1;
    p.reserve(n);
    for (u32 i = 0; i < n; i++) {
        u32 x = 0;
        // Entries are used as indices by perm(), so reject anything out of range.
        if (std::scanf("%u", &x) != 1 || x < 1 || x > n) return 1;
        p.push_back(x - 1);
    }
    std::printf("%u\n", solve(m, p).n);
    return 0;
}