/**
 * @FileName a.cpp
 * @Author kanpurin
 * @Created 2021.02.21 15:16:07
 *
 * Reads n, m and n 1-indexed values p[0..n), takes the sizes c of the connected
 * components of the graph {i -- p[i]-1} (the cycle lengths when p is a permutation)
 * and prints, modulo 998244353,
 *
 *   sum_{i=m}^{n} (-1)^{n+i} * [x^i] prod_c ((1+x)^c - 1 - (c-1)x) * S2(i, m),
 *
 * where S2(i, m) = i! [x^i] (e^x - 1)^m / m! (Stirling number of the second kind).
 **/
#include "bits/stdc++.h"
using namespace std;
typedef long long ll;

/// Integer modulo MOD, stored as a value in [0, MOD).
/// Division uses Fermat's little theorem (a / b = a * b^(MOD-2)), so MOD must be prime.
template <int MOD>
struct mint {
  public:
    long long x;

    mint(long long x = 0) : x((x % MOD + MOD) % MOD) {}

    /// Parses a string of decimal digits, reducing modulo MOD after every digit.
    mint(const std::string &s) {
        long long z = 0;
        for (char ch : s) z = (z * 10 + (ch - '0')) % MOD;
        this->x = z;
    }

    mint &operator+=(const mint &a) {
        if ((x += a.x) >= MOD) x -= MOD;
        return *this;
    }
    mint &operator-=(const mint &a) {
        if ((x += MOD - a.x) >= MOD) x -= MOD;
        return *this;
    }
    mint &operator*=(const mint &a) {
        (x *= a.x) %= MOD;
        return *this;
    }
    /// Multiplies by a^(MOD-2), the inverse of a for prime MOD (dividing by 0 yields 0).
    mint &operator/=(const mint &a) { return *this *= a.pow(MOD - 2); }

    mint operator+(const mint &a) const { return mint(*this) += a; }
    mint operator-() const { return mint() -= *this; }
    mint operator-(const mint &a) const { return mint(*this) -= a; }
    mint operator*(const mint &a) const { return mint(*this) *= a; }
    mint operator/(const mint &a) const { return mint(*this) /= a; }

    friend std::ostream &operator<<(std::ostream &os, const mint &n) { return os << n.x; }
    friend std::istream &operator>>(std::istream &is, mint &n) {
        long long x;
        is >> x;
        n = mint(x);
        return is;
    }

    bool operator==(const mint &a) const { return this->x == a.x; }
    bool operator!=(const mint &a) const { return this->x != a.x; }

    /// this^k by binary exponentiation (k >= 0).
    mint pow(long long k) const {
        mint ret = 1;
        mint p = this->x;
        while (k > 0) {
            if (k & 1) ret *= p;
            p *= p;
            k >>= 1;
        }
        return ret;
    }
};

/// Truncated formal power series with coefficients in mint<MOD>.
/// any == false: MOD must be an NTT-friendly prime (e.g. 998244353).
/// any == true : products are computed modulo three NTT primes and recombined with
///               Garner's algorithm, so MOD may be arbitrary as long as the exact
///               product coefficients stay below 998244353 * 1300234241 * 1484783617.
template <const int MOD, bool any = false>
struct FormalPowerSeries {
  private:
    using P = FormalPowerSeries<MOD, any>;
    using M = mint<MOD>;

    /// Computes e op a op ... op a (n copies of a) by repeated squaring; n >= 0.
    template <class T, class F = multiplies<T>>
    T power(T a, long long n, F op = multiplies<T>(), T e = {1}) const {
        assert(n >= 0);
        T res = e;
        while (n) {
            if (n & 1) res = op(res, a);
            if (n >>= 1) a = op(a, a);
        }
        return res;
    }

    /// In-place number-theoretic transform over mint<_MOD>; a.size() must be a power of two.
    /// The forward pass keeps values lazily reduced (below 2*_MOD) and normalizes them at the
    /// end; inverse == true runs the matching inverse pass including the 1/n scaling. No
    /// bit-reversal permutation is performed, so forward/inverse are meant to be used as a pair.
    template <int _MOD>
    void ntt(vector<mint<_MOD>> &a, bool inverse) {
        // Twiddle factors, built once per modulus from a quadratic non-residue.
        static vector<mint<_MOD>> dw(30), idw(30);
        if (dw[0] == 0) {
            mint<_MOD> root = 2;
            while (power(root, (_MOD - 1) / 2) == 1) root += 1;
            for (int i = 0; i < 30; i++)
                dw[i] = -power(root, (_MOD - 1) >> (i + 2)), idw[i] = mint<_MOD>(1) / dw[i];
        }
        int n = a.size();
        assert((n & (n - 1)) == 0);
        if (not inverse) {
            for (int m = n; m >>= 1;) {
                mint<_MOD> w = 1;
                for (int s = 0, k = 0; s < n; s += 2 * m) {
                    for (int i = s, j = s + m; i < s + m; i++, j++) {
                        auto x = a[i], y = a[j] * w;
                        if (x.x >= _MOD) x.x -= _MOD;
                        a[i].x = x.x + y.x, a[j].x = x.x + (_MOD - y.x);
                    }
                    w *= dw[__builtin_ctz(++k)];
                }
            }
        } else {
            for (int m = 1; m < n; m *= 2) {
                mint<_MOD> w = 1;
                for (int s = 0, k = 0; s < n; s += 2 * m) {
                    for (int i = s, j = s + m; i < s + m; i++, j++) {
                        auto x = a[i], y = a[j];
                        a[i] = x + y, a[j].x = x.x + (_MOD - y.x), a[j] *= w;
                    }
                    w *= idw[__builtin_ctz(++k)];
                }
            }
        }
        // Final pass also reduces every lazily-reduced value into [0, _MOD).
        auto c = mint<_MOD>(1) / mint<_MOD>(inverse ? n : 1);
        for (auto &&e : a) e *= c;
    }

    /// Product of two polynomials modulo _MOD. Uses the schoolbook method when the shorter
    /// operand has fewer than 30 terms, an NTT otherwise.
    template <int _MOD>
    vector<mint<_MOD>> convolution(vector<mint<_MOD>> l, vector<mint<_MOD>> r) {
        if (l.empty() || r.empty()) return {};
        int n = l.size(), m = r.size(), sz = 1 << __lg(2 * (n + m - 1) - 1);
        if (min(n, m) < 30) {
            // Fewer than 30 products (< 2^31 each) per cell cannot overflow long long.
            vector<long long> res(n + m - 1);
            for (int i = 0; i < n; i++)
                for (int j = 0; j < m; j++) res[i + j] += (l[i] * r[j]).x;
            return {begin(res), end(res)};
        }
        bool eq = l == r;
        l.resize(sz), ntt(l, false);
        if (eq)
            r = l;
        else
            r.resize(sz), ntt(r, false);
        for (int i = 0; i < sz; i++) l[i] *= r[i];
        ntt(l, true), l.resize(n + m - 1);
        return l;
    }

    /// First min(p.size(), sz) coefficients of p; empty when sz <= 0.
    P pre(const P &p, int sz) const {
        const int len = max(0, min((int)p.a.size(), sz));
        P ret;
        ret.a.assign(p.a.begin(), p.a.begin() + len);
        return ret;
    }

  public:
    vector<M> a;  // a[i] is the coefficient of x^i

    FormalPowerSeries(int sz = 0) { this->a.resize(sz, 0); }
    FormalPowerSeries(std::initializer_list<M> v) { this->a = v; }

    /// Returns a copy truncated to at most k coefficients; *this is NOT modified.
    [[nodiscard]] P resize(int k) const { return pre(*this, k); }

    size_t size() const { return this->a.size(); }

    // Ordering by number of stored coefficients, not by value.
    bool operator<(const P &r) const { return this->a.size() < r.a.size(); }
    bool operator>(const P &r) const { return this->a.size() > r.a.size(); }

    P operator+(const P &r) const { return P(*this) += r; }
    P operator+(const long long v) const { return P(*this) += v; }
    P operator-(const P &r) const { return P(*this) -= r; }
    P operator*(const P &r) const { return P(*this) *= r; }
    P operator*(const long long v) const { return P(*this) *= v; }
    P operator/(const P &r) const { return P(*this) /= r; }
    P operator/(const M &v) const { return P(*this) /= v; }

    P &operator+=(const P &r) {
        this->a.resize(max(this->a.size(), r.size()));
        for (int i = 0; i < (int)r.size(); i++) this->a[i] += r.a[i];
        return *this;
    }
    /// Adds the constant v to the x^0 coefficient (creating it if the series is empty).
    P &operator+=(const long long v) {
        if (this->a.size() == 0)
            this->a.resize(1, (v % MOD + MOD) % MOD);
        else
            this->a[0] += v;
        return *this;
    }
    P &operator-=(const P &r) {
        this->a.resize(max(this->a.size(), r.size()));
        for (int i = 0; i < (int)r.size(); i++) this->a[i] -= r.a[i];
        return *this;
    }

    /// Full (untruncated) polynomial product.
    P &operator*=(const P &b) {
        if (!any) {
            this->a = convolution(this->a, b.a);
            return *this;
        }
        if (this->a.empty() || b.a.empty()) {
            this->a.clear();
            return *this;
        }
        int n = this->a.size(), m = b.a.size();
        static constexpr int mod0 = 998244353, mod1 = 1300234241, mod2 = 1484783617;
        using Mint0 = mint<mod0>;
        using Mint1 = mint<mod1>;
        using Mint2 = mint<mod2>;
        vector<Mint0> l0(n), r0(m);
        vector<Mint1> l1(n), r1(m);
        vector<Mint2> l2(n), r2(m);
        for (int i = 0; i < n; i++) l0[i] = this->a[i].x, l1[i] = this->a[i].x, l2[i] = this->a[i].x;
        for (int j = 0; j < m; j++) r0[j] = b.a[j].x, r1[j] = b.a[j].x, r2[j] = b.a[j].x;
        l0 = convolution(l0, r0);
        l1 = convolution(l1, r1);
        l2 = convolution(l2, r2);
        this->a.resize(n + m - 1);
        // Garner reconstruction: value = y0 + y1 * mod0 + y2 * mod0 * mod1, then reduced mod MOD.
        static const Mint1 im0 = Mint1(1) / Mint1(mod0);
        static const Mint2 im1 = Mint2(1) / Mint2(mod1), im0m1 = im1 / mod0;
        static const M m0 = mod0, m0m1 = m0 * mod1;
        for (int i = 0; i < n + m - 1; i++) {
            int y0 = l0[i].x;
            int y1 = (im0 * (l1[i] - y0)).x;
            int y2 = (im0m1 * (l2[i] - y0) - im1 * y1).x;
            this->a[i] = m0m1 * y2 + y0 + m0 * y1;
        }
        return *this;
    }
    P &operator*=(const long long v) {
        const M factor = v;
        for (auto &coef : this->a) coef *= factor;
        return *this;
    }
    /// Multiplies by r.inverse() (r's inverse truncated to r.size() terms).
    P &operator/=(const P &r) {
        *this *= r.inverse();
        return *this;
    }
    /// Divides every coefficient by v (one modular inversion for the whole series).
    P &operator/=(const M &v) {
        const M inv = M(1) / v;
        for (auto &coef : this->a) coef *= inv;
        return *this;
    }

    /// 1 / f mod x^deg (deg defaults to size()) by Newton iteration g <- 2g - g^2 f.
    /// Requires a nonzero constant term.
    P inverse(int deg = -1) const {
        assert(this->a.size() != 0 && this->a[0].x != 0);
        const int n = (int)this->a.size();
        if (deg == -1) deg = n;
        P ret(1);
        ret.a[0] = M(1) / this->a[0];
        for (int i = 1; i < deg; i <<= 1) {
            ret = pre(ret + ret - ret * ret * pre(*this, i << 1), i << 1);
        }
        return pre(ret, deg);
    }

    /// Formal derivative f'.
    P differential() const {
        const int n = (int)this->a.size();
        P ret(max(0, n - 1));
        for (int i = 1; i < n; i++) ret.a[i - 1] = this->a[i] * i;
        return ret;
    }

    /// Formal antiderivative with zero constant term.
    P integral() const {
        const int n = (int)this->a.size();
        P ret(n + 1);
        for (int i = 0; i < n; i++) ret.a[i + 1] = this->a[i] / (i + 1);
        return ret;
    }

    /// log f mod x^deg computed as the integral of f' / f. Requires f[0] == 1.
    P log(int deg = -1) const {
        assert(this->a.size() != 0 && this->a[0] == 1);
        const int n = (int)this->a.size();
        if (deg == -1) deg = n;
        return pre(this->differential() * this->inverse(deg), deg - 1).integral();
    }

    /// exp f mod x^deg by Newton iteration g <- g (1 + f - log g). Requires f[0] == 0;
    /// an empty series yields an empty result.
    P exp(int deg = -1) const {
        if (this->a.size() == 0) return P(0);
        assert(this->a[0] == 0);
        const int n = (int)this->a.size();
        if (deg == -1) deg = n;
        P ret(1);
        ret.a[0] = 1;
        for (int i = 1; i < deg; i <<= 1) {
            ret = pre(ret * (pre(*this, i << 1) + 1 - ret.log(i << 1)), i << 1);
        }
        return pre(ret, deg);
    }

    /// f^k mod x^deg (deg defaults to size()); k >= 0. Always returns exactly deg
    /// coefficients. If f = c x^i (1 + g), then f^k = c^k x^{ik} exp(k log(1 + g)),
    /// and only deg - i*k terms of the exp/log part are needed.
    P pow(long long k, int deg = -1) const {
        assert(k >= 0);
        const int n = (int)this->a.size();
        if (deg == -1) deg = n;
        P ret(deg);
        if (deg == 0) return ret;
        if (k == 0) {  // f^0 = 1, even for the zero series
            ret.a[0] = 1;
            return ret;
        }
        for (int i = 0; i < n; i++) {
            if (this->a[i].x == 0) continue;
            // i*k >= deg  <=>  k >= ceil(deg / i); checked this way so i*k cannot overflow.
            if (i != 0 && k >= (deg + i - 1LL) / i) return ret;
            const int shift = (int)(i * k);
            const int len = deg - shift;  // >= 1
            const M lead = this->a[i];
            const M inv_lead = M(1) / lead;
            P D(n - i);  // f / (lead * x^i), constant term 1
            for (int j = i; j < n; j++) D.a[j - i] = this->a[j] * inv_lead;
            D = (D.log(len) * k).exp(len) * lead.pow(k).x;
            for (int j = 0; j < len && j < (int)D.size(); j++) ret.a[j + shift] = D.a[j];
            return ret;
        }
        return ret;  // zero series, k >= 1
    }

    M &operator[](int x) {
        assert(0 <= x && x < (int)this->a.size());
        return a[x];
    }
    const M &operator[](int x) const {
        assert(0 <= x && x < (int)this->a.size());
        return a[x];
    }

    friend std::ostream &operator<<(std::ostream &os, const P &p) {
        os << "[ ";
        for (const auto &coef : p.a) os << coef << " ";
        os << "]";
        return os;
    }
};

constexpr int MOD = 998244353;

/// Disjoint-set union with union by size and path compression.
class UnionFind {
  private:
    vector<int> par;  // par[x] < 0: x is a root and -par[x] is its set size; else parent of x

  public:
    UnionFind(int n) { par.resize(n, -1); }

    int root(int x) {
        if (par[x] < 0) return x;
        return par[x] = root(par[x]);
    }

    /// Merges the sets of x and y; returns false if they were already the same set.
    bool unite(int x, int y) {
        int rx = root(x);
        int ry = root(y);
        if (rx == ry) return false;
        if (size(rx) < size(ry)) swap(rx, ry);
        par[rx] += par[ry];
        par[ry] = rx;
        return true;
    }

    bool same(int x, int y) { return root(x) == root(y); }

    int size(int x) { return -par[root(x)]; }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m;
    if (!(cin >> n >> m)) return 0;
    vector<int> p(n);
    UnionFind uf(n);
    for (int i = 0; i < n; i++) {
        cin >> p[i];
        uf.unite(i, p[i] - 1);
    }

    // Sizes of the connected components of the graph i -- p[i]-1.
    vector<int> c;
    for (int i = 0; i < n; i++) {
        if (uf.root(i) == i) c.push_back(uf.size(i));
    }

    // S[i] = i! [x^i] (e^x - 1)^m / m!. At least two terms so that x^1 exists.
    FormalPowerSeries<MOD> S(max(n + 1, 2));
    S[1] = 1;
    S = S.exp();
    S[0] -= 1;
    mint<MOD> fact = 1;
    for (int i = 2; i <= m; i++) fact *= i;
    S = S.pow(m) / fact;
    fact = 1;
    for (int i = 1; i <= n; i++) {
        fact *= i;
        S[i] *= fact;
    }

    // One factor (1+x)^len - 1 - (len-1)x per component.
    deque<FormalPowerSeries<MOD>> polys;
    for (int len : c) {
        FormalPowerSeries<MOD> b(len + 1);
        b[0] = 1;
        b[1] = 1;
        b = b.pow(len);  // (1+x)^len, exact because it keeps len + 1 terms
        b[1] -= len - 1;
        b[0] -= 1;
        polys.push_back(std::move(b));
    }

    // Multiply the factors pairwise in FIFO order (a balanced product tree): every
    // coefficient takes part in O(log #components) products instead of the
    // O(#components) of a left-to-right accumulation.
    FormalPowerSeries<MOD> G(1);
    G[0] = 1;
    while (polys.size() > 1) {
        FormalPowerSeries<MOD> lhs = std::move(polys.front());
        polys.pop_front();
        FormalPowerSeries<MOD> rhs = std::move(polys.front());
        polys.pop_front();
        polys.push_back(lhs * rhs);
    }
    if (!polys.empty()) G = std::move(polys.front());
    G.a.resize(n + 1);  // the product has degree sum(c) = n; make G[0..n] addressable

    mint<MOD> ans = 0;
    for (int i = m; i <= n; i++) {
        const mint<MOD> term = G[i] * S[i];
        if ((n + i) & 1)
            ans -= term;
        else
            ans += term;
    }
    cout << ans << '\n';
    return 0;
}