結果
問題 | No.1392 Don't be together |
ユーザー | 🍮かんプリン |
提出日時 | 2021-02-21 16:52:57 |
言語 | C++14 (gcc 12.3.0 + boost 1.83.0) |
結果 | AC |
実行時間 | 108 ms / 2,000 ms |
コード長 | 13,337 bytes |
コンパイル時間 | 2,675 ms |
コンパイル使用メモリ | 203,088 KB |
実行使用メモリ | 6,944 KB |
最終ジャッジ日時 | 2024-09-19 14:51:33 |
合計ジャッジ時間 | 4,812 ms |
ジャッジサーバーID (参考情報) | judge5 / judge1 |
(要ログイン)
テストケース
テストケース表示
入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 2 ms | 6,812 KB |
testcase_01 | AC | 2 ms | 6,812 KB |
testcase_02 | AC | 2 ms | 6,820 KB |
testcase_03 | AC | 2 ms | 6,940 KB |
testcase_04 | AC | 2 ms | 6,940 KB |
testcase_05 | AC | 1 ms | 6,944 KB |
testcase_06 | AC | 108 ms | 6,940 KB |
testcase_07 | AC | 55 ms | 6,940 KB |
testcase_08 | AC | 60 ms | 6,944 KB |
testcase_09 | AC | 92 ms | 6,944 KB |
testcase_10 | AC | 52 ms | 6,940 KB |
testcase_11 | AC | 53 ms | 6,940 KB |
testcase_12 | AC | 52 ms | 6,944 KB |
testcase_13 | AC | 52 ms | 6,940 KB |
testcase_14 | AC | 50 ms | 6,940 KB |
testcase_15 | AC | 59 ms | 6,944 KB |
testcase_16 | AC | 52 ms | 6,944 KB |
testcase_17 | AC | 51 ms | 6,940 KB |
testcase_18 | AC | 58 ms | 6,940 KB |
testcase_19 | AC | 52 ms | 6,944 KB |
testcase_20 | AC | 29 ms | 6,940 KB |
testcase_21 | AC | 15 ms | 6,944 KB |
testcase_22 | AC | 50 ms | 6,940 KB |
testcase_23 | AC | 28 ms | 6,944 KB |
testcase_24 | AC | 28 ms | 6,940 KB |
testcase_25 | AC | 29 ms | 6,940 KB |
testcase_26 | AC | 15 ms | 6,944 KB |
testcase_27 | AC | 28 ms | 6,940 KB |
testcase_28 | AC | 26 ms | 6,940 KB |
testcase_29 | AC | 26 ms | 6,940 KB |
ソースコード
/**
 * @FileName a.cpp
 * @Author kanpurin
 * @Created 2021.02.21 16:52:50
 **/
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Integer modulo MOD. Division assumes MOD is prime (Fermat's little theorem).
// The public operators keep `x` in [0, MOD); the NTT in FormalPowerSeries writes
// `x` directly and temporarily lets it reach [0, 2 * MOD). Every modulus used in
// this file is below 2^31, so such products still fit in a long long.
template <int MOD>
struct mint {
  public:
    long long x;

    mint(long long x = 0) : x((x % MOD + MOD) % MOD) {}

    // Parses a decimal digit string, reducing modulo MOD after every digit.
    // No sign handling or validation is performed.
    mint(const std::string &s) {
        long long z = 0;
        for (char ch : s) {
            z *= 10;
            z += ch - '0';
            z %= MOD;
        }
        this->x = z;
    }

    mint &operator+=(const mint &a) {
        if ((x += a.x) >= MOD) x -= MOD;
        return *this;
    }
    mint &operator-=(const mint &a) {
        if ((x += MOD - a.x) >= MOD) x -= MOD;
        return *this;
    }
    mint &operator*=(const mint &a) {
        (x *= a.x) %= MOD;
        return *this;
    }
    // Multiplies by a^(MOD-2): the inverse of a for prime MOD and a != 0 (a == 0 gives 0).
    mint &operator/=(const mint &a) {
        long long e = MOD - 2;
        mint inv = 1, base = a;
        while (e > 0) {
            if (e & 1) inv *= base;
            base *= base;
            e >>= 1;
        }
        return *this *= inv;
    }
    mint operator+(const mint &a) const { mint res(*this); return res += a; }
    mint operator-() const { return mint() -= *this; }
    mint operator-(const mint &a) const { mint res(*this); return res -= a; }
    mint operator*(const mint &a) const { mint res(*this); return res *= a; }
    mint operator/(const mint &a) const { mint res(*this); return res /= a; }
    friend std::ostream &operator<<(std::ostream &os, const mint &n) { return os << n.x; }
    friend std::istream &operator>>(std::istream &is, mint &n) {
        long long x;
        is >> x;
        n = mint(x);
        return is;
    }
    bool operator==(const mint &a) const { return this->x == a.x; }
    bool operator!=(const mint &a) const { return this->x != a.x; }
    // Returns this^k by binary exponentiation (k <= 0 returns 1).
    mint pow(long long k) const {
        mint ret = 1;
        mint p = this->x;
        while (k > 0) {
            if (k & 1) ret *= p;
            p *= p;
            k >>= 1;
        }
        return ret;
    }
};

// Truncated formal power series with coefficients in Z/MOD; a[i] is the
// coefficient of x^i.
//  - any == false: products use one NTT modulo MOD, so MOD must be an
//    NTT-friendly prime (such as 998244353).
//  - any == true : products are computed modulo three NTT primes and recombined
//    with Garner's algorithm (CRT), so MOD may be arbitrary.
// Methods with a `deg` parameter return the first `deg` coefficients; deg == -1
// means "the current size".
template <const int MOD, bool any = false>
struct FormalPowerSeries {
  private:
    using P = FormalPowerSeries<MOD, any>;

    // Generic binary exponentiation: combines `a` with itself n times via `op`,
    // starting from the identity `e`.
    template <class T, class F = multiplies<T>>
    T power(T a, long long n, F op = multiplies<T>(), T e = {1}) const {
        assert(n >= 0);
        T res = e;
        while (n) {
            if (n & 1) res = op(res, a);
            if (n >>= 1) a = op(a, a);
        }
        return res;
    }

    // In-place number-theoretic transform of `a`; its size must be a power of two.
    // The forward pass keeps lazily reduced values (< 2 * PMOD) that the final
    // scaling loop normalises; the inverse pass also divides by the size.
    template <int PMOD>
    void ntt(vector<mint<PMOD>> &a, bool inverse) {
        // Twiddle tables, computed once per modulus on first use.
        static vector<mint<PMOD>> dw(30), idw(30);
        if (dw[0] == 0) {
            mint<PMOD> root = 2;
            // Smallest quadratic non-residue (Euler's criterion).
            while (power(root, (PMOD - 1) / 2) == 1) root += 1;
            for (int i = 0; i < 30; i++) {
                dw[i] = -power(root, (PMOD - 1) >> (i + 2));
                idw[i] = mint<PMOD>(1) / dw[i];
            }
        }
        const int n = a.size();
        assert((n & (n - 1)) == 0);
        if (!inverse) {
            for (int m = n; m >>= 1;) {
                mint<PMOD> w = 1;
                for (int s = 0, k = 0; s < n; s += 2 * m) {
                    for (int i = s, j = s + m; i < s + m; i++, j++) {
                        auto x = a[i], y = a[j] * w;
                        if (x.x >= PMOD) x.x -= PMOD;
                        a[i].x = x.x + y.x;
                        a[j].x = x.x + (PMOD - y.x);
                    }
                    w *= dw[__builtin_ctz(++k)];
                }
            }
        } else {
            for (int m = 1; m < n; m *= 2) {
                mint<PMOD> w = 1;
                for (int s = 0, k = 0; s < n; s += 2 * m) {
                    for (int i = s, j = s + m; i < s + m; i++, j++) {
                        auto x = a[i], y = a[j];
                        a[i] = x + y;
                        a[j].x = x.x + (PMOD - y.x);
                        a[j] *= w;
                    }
                    w *= idw[__builtin_ctz(++k)];
                }
            }
        }
        const auto c = mint<PMOD>(1) / mint<PMOD>(inverse ? n : 1);
        for (auto &&e : a) e *= c;
    }

    // Linear (non-cyclic) product of l and r modulo PMOD; result size |l| + |r| - 1.
    // Uses schoolbook multiplication when either side is shorter than 30.
    template <int PMOD>
    vector<mint<PMOD>> convolution(vector<mint<PMOD>> l, vector<mint<PMOD>> r) {
        if (l.empty() || r.empty()) return {};
        const int n = l.size(), m = r.size();
        // Smallest power of two >= n + m - 1.
        const int sz = 1 << __lg(2 * (n + m - 1) - 1);
        if (min(n, m) < 30) {
            // Each entry sums fewer than 30 values below PMOD: no overflow.
            vector<long long> res(n + m - 1);
            for (int i = 0; i < n; i++)
                for (int j = 0; j < m; j++) res[i + j] += (l[i] * r[j]).x;
            return {begin(res), end(res)};
        }
        const bool eq = l == r;
        l.resize(sz);
        ntt(l, false);
        if (eq) {
            r = l;  // squaring: reuse the forward transform
        } else {
            r.resize(sz);
            ntt(r, false);
        }
        for (int i = 0; i < sz; i++) l[i] *= r[i];
        ntt(l, true);
        l.resize(n + m - 1);
        return l;
    }

    // First min(|p|, sz) coefficients of p (never grows the series).
    P pre(const P &p, int sz) const {
        P ret;
        ret.a = vector<mint<MOD>>(p.a.begin(), p.a.begin() + min((int)p.a.size(), sz));
        return ret;
    }

  public:
    vector<mint<MOD>> a;  // a[i] is the coefficient of x^i

    FormalPowerSeries(int sz = 0) { this->a.resize(sz, 0); }
    FormalPowerSeries(std::initializer_list<mint<MOD>> v) { this->a = v; }

    // Returns the first k coefficients; does not extend a shorter series.
    P resize(int k) const { return pre(*this, k); }
    size_t size() const { return this->a.size(); }

    // NOTE: these orderings compare sizes only, not coefficients.
    bool operator<(const P &r) const { return this->a.size() < r.a.size(); }
    bool operator>(const P &r) const { return this->a.size() > r.a.size(); }

    P operator+(const P &a) const { return P(*this) += a; }
    P operator+(const long long a) const { return P(*this) += a; }
    P operator-(const P &a) const { return P(*this) -= a; }
    P operator*(const P &a) const { return P(*this) *= a; }
    P operator*(const long long a) const { return P(*this) *= a; }
    P operator/(const P &a) const { return P(*this) /= a; }
    P operator/(const mint<MOD> &a) const { return P(*this) /= a; }

    P &operator+=(const P &r) {
        this->a.resize(max(this->a.size(), r.size()));
        for (int i = 0; i < (int)r.size(); i++) this->a[i] += r.a[i];
        return *this;
    }
    // Adds the constant v (creating the x^0 term for an empty series).
    P &operator+=(const long long v) {
        if (this->a.size() == 0)
            this->a.resize(1, (v % MOD + MOD) % MOD);
        else
            this->a[0] += v;
        return *this;
    }
    P &operator-=(const P &r) {
        this->a.resize(max(this->a.size(), r.size()));
        for (int i = 0; i < (int)r.size(); i++) this->a[i] -= r.a[i];
        return *this;
    }
    // Full (untruncated) product: the result has |this| + |b| - 1 terms.
    P &operator*=(const P &b) {
        if (!any) {
            this->a = convolution(this->a, b.a);
            return *this;
        } else {
            if (this->a.empty() || b.a.empty()) {
                this->a.clear();
                return *this;
            }
            int n = this->a.size(), m = b.a.size();
            static constexpr int mod0 = 998244353, mod1 = 1300234241, mod2 = 1484783617;
            using Mint0 = mint<mod0>;
            using Mint1 = mint<mod1>;
            using Mint2 = mint<mod2>;
            vector<Mint0> l0(n), r0(m);
            vector<Mint1> l1(n), r1(m);
            vector<Mint2> l2(n), r2(m);
            for (int i = 0; i < n; i++) l0[i] = this->a[i].x, l1[i] = this->a[i].x, l2[i] = this->a[i].x;
            for (int j = 0; j < m; j++) r0[j] = b.a[j].x, r1[j] = b.a[j].x, r2[j] = b.a[j].x;
            l0 = convolution(l0, r0);
            l1 = convolution(l1, r1);
            l2 = convolution(l2, r2);
            this->a.resize(n + m - 1);
            // Garner: value = y0 + mod0 * y1 + mod0 * mod1 * y2, reduced modulo MOD.
            static const Mint1 im0 = Mint1(1) / Mint1(mod0);
            static const Mint2 im1 = Mint2(1) / Mint2(mod1), im0m1 = im1 / mod0;
            static const mint<MOD> m0 = mod0, m0m1 = m0 * mod1;
            for (int i = 0; i < n + m - 1; i++) {
                int y0 = l0[i].x;
                int y1 = (im0 * (l1[i] - y0)).x;
                int y2 = (im0m1 * (l2[i] - y0) - im1 * y1).x;
                this->a[i] = m0m1 * y2 + y0 + m0 * y1;
            }
            return *this;
        }
    }
    P &operator*=(const long long v) {
        const mint<MOD> mv = v;  // reduce v once instead of per coefficient
        for (auto &coef : this->a) coef *= mv;
        return *this;
    }
    P &operator/=(const P &a) {
        *this *= a.inverse();
        return *this;
    }
    P &operator/=(const mint<MOD> &v) {
        const mint<MOD> inv = mint<MOD>(1) / v;  // one exponentiation, not one per coefficient
        for (auto &coef : this->a) coef *= inv;
        return *this;
    }

    // Multiplicative inverse truncated to deg terms (Newton iteration).
    // Requires a non-empty series with a[0] != 0.
    P inverse(int deg = -1) const {
        assert(this->a.size() != 0 && this->a[0].x != 0);
        const int n = (int)this->a.size();
        if (deg == -1) deg = n;
        P ret(1);
        ret[0] = mint<MOD>(1) / a[0];
        for (int i = 1; i < deg; i <<= 1) {
            ret = pre((ret + ret - ret * ret * pre(*this, i << 1)), i << 1);
        }
        return pre(ret, deg);
    }
    // Formal derivative (size |a| - 1, at least 0).
    P differential() const {
        const int n = (int)this->a.size();
        P ret(max(0, n - 1));
        for (int i = 1; i < n; i++) ret[i - 1] = this->a[i] * i;
        return ret;
    }
    // Formal antiderivative with zero constant term (size |a| + 1).
    P integral() const {
        const int n = (int)this->a.size();
        P ret(n + 1);
        for (int i = 0; i < n; i++) ret[i + 1] = this->a[i] / (i + 1);
        return ret;
    }
    // log(f) truncated to deg terms; requires a[0] == 1.
    P log(int deg = -1) const {
        assert(this->a.size() != 0 && this->a[0] == 1);
        const int n = (int)this->a.size();
        if (deg == -1) deg = n;
        if (deg == 0) return P(0);  // avoid pre(..., -1) below
        return pre((this->differential() * this->inverse(deg)), deg - 1).integral();
    }
    // exp(f) truncated to deg terms (Newton iteration); requires a[0] == 0.
    // An empty series is treated as 0, so exp gives 1 when deg > 0.
    P exp(int deg = -1) const {
        const int n = (int)this->a.size();
        assert(n == 0 || this->a[0] == 0);
        if (deg == -1) deg = n;
        if (deg == 0) return P(0);
        P ret(1);
        ret.a[0] = 1;
        for (int i = 1; i < deg; i <<= 1) {
            ret = pre((ret * (pre(*this, i << 1) + 1 - ret.log(i << 1))), i << 1);
        }
        return pre(ret, deg);
    }
    // f^k truncated to deg terms (k >= 0). Writes f = c * x^i * g with g(0) = 1
    // and evaluates g^k as exp(k * log g). The result always has exactly deg terms.
    P pow(long long k, int deg = -1) const {
        assert(k >= 0);
        const int n = (int)this->a.size();
        if (deg == -1) deg = n;
        P ret(deg);
        for (int i = 0; i < n; i++) {
            if (this->a[i].x == 0) continue;
            // The result starts at x^(i*k); if i*k >= deg it truncates to zero.
            // Test k >= ceil(deg / i) instead of i*k >= deg so huge k cannot overflow.
            if (i > 0 && k >= ((long long)deg + i - 1) / i) return ret;
            const int shift = (int)(i * k);  // < deg here (or 0 when i == 0)
            if (shift >= deg) return ret;    // only reachable when deg == 0
            const int need = deg - shift;
            const mint<MOD> lead = this->a[i];
            const mint<MOD> inv_lead = mint<MOD>(1) / lead;
            P g(n - i);
            for (int j = i; j < n; j++) g[j - i] = this->a[j] * inv_lead;
            // Expand log/exp to `need` terms (not |g|) so that deg > size() still
            // yields every requested coefficient.
            g = (g.log(need) * k).exp(need) * power(lead, k).x;
            for (int j = 0; j < need && j < (int)g.size(); j++) ret[j + shift] = g[j];
            return ret;
        }
        // f == 0: 0^0 = 1, otherwise the zero series.
        if (k == 0 && deg > 0) ret[0] = 1;
        return ret;
    }

    mint<MOD> &operator[](int x) {
        assert(0 <= x && x < (int)this->a.size());
        return a[x];
    }
    const mint<MOD> &operator[](int x) const {
        assert(0 <= x && x < (int)this->a.size());
        return a[x];
    }
    friend std::ostream &operator<<(std::ostream &os, const P &p) {
        os << "[ ";
        for (const auto &coef : p.a) os << coef << " ";
        os << "]";
        return os;
    }
};

constexpr int MOD = 998244353;

// Disjoint-set union with union by size and path compression.
class UnionFind {
  private:
    vector<int> par;  // par[x] < 0: x is a root and -par[x] is its set size

  public:
    UnionFind(int n) { par.resize(n, -1); }
    int root(int x) {
        if (par[x] < 0) return x;
        return par[x] = root(par[x]);
    }
    bool unite(int x, int y) {
        int rx = root(x);
        int ry = root(y);
        if (rx == ry) return false;
        if (size(rx) < size(ry)) swap(rx, ry);
        par[rx] += par[ry];
        par[ry] = rx;
        return true;
    }
    bool same(int x, int y) { return root(x) == root(y); }
    int size(int x) { return -par[root(x)]; }
};

int n, m;
// After main() fills it: S[i] = S(i, m), Stirling numbers of the second kind, 0 <= i <= n.
FormalPowerSeries<MOD> S;

// Product of polys[lo, hi) using a balanced product tree. With total degree n,
// this costs O(n log^2 n) instead of the O(n^2) of multiplying one factor at a
// time (many tiny factors always hit the schoolbook path). Empty range -> {1}.
FormalPowerSeries<MOD> multiply_all(const vector<FormalPowerSeries<MOD>> &polys, int lo, int hi) {
    if (hi <= lo) {
        FormalPowerSeries<MOD> one(1);
        one[0] = 1;
        return one;
    }
    if (hi - lo == 1) return polys[lo];
    const int mid = lo + (hi - lo) / 2;
    return multiply_all(polys, lo, mid) * multiply_all(polys, mid, hi);
}

// For each cycle length c, b_c(x) = (1 + x)^c - 1 - x. Up to the sign
// (-1)^(c - j), its coefficients are those of the chromatic polynomial of a
// c-cycle, (k - 1)^c + (-1)^c (k - 1), in powers k^j. Multiplying all b_c and
// pairing x^i with S(i, m) gives
//   ans = sum_{i=m}^{n} (-1)^(n+i) [x^i] prod_c b_c(x) * S(i, m).
// NOTE(review): the problem statement is not visible here. This presumably counts
// partitions into m non-empty groups with i and p_i never in the same group.
mint<MOD> solve1(const vector<int> &c) {
    vector<FormalPowerSeries<MOD>> factors;
    factors.reserve(c.size());
    for (int len : c) {
        FormalPowerSeries<MOD> b(len + 1);
        b[0] = 1;
        b[1] = 1;
        b = b.pow(len);  // (1 + x)^len, exact since it has len + 1 terms
        b[0] -= 1;
        b[1] -= 1;
        factors.push_back(b);
    }
    const FormalPowerSeries<MOD> a = multiply_all(factors, 0, (int)factors.size());
    mint<MOD> ans = 0;
    for (int i = m; i <= n && i < (int)a.size(); i++) {
        const mint<MOD> term = a[i] * S[i];
        if ((n + i) & 1) {
            ans -= term;
        } else {
            ans += term;
        }
    }
    return ans;
}

// Alternative evaluation of the same sum (not called from main). The cycle
// polynomial vanishes at k = 0, so each factor is divided by k. Its quotient has
// coefficients (-1)^c * b'_j, where b' = ((1 - x)^c - 1) / x + 1 truncated to c
// terms. The result is then paired with S(#cycles + i, m).
mint<MOD> solve2(const vector<int> &c) {
    vector<FormalPowerSeries<MOD>> factors;
    factors.reserve(c.size());
    for (int len : c) {
        FormalPowerSeries<MOD> b(len + 1);
        b[1] = MOD - 1;
        b[0] = 1;
        b = b.pow(len);  // (1 - x)^len
        for (int j = 0; j < len; j++) b[j] = b[j + 1];  // divide by x (drop constant)
        b[0] += 1;
        b = b.resize(len);
        factors.push_back(b);
    }
    const FormalPowerSeries<MOD> a = multiply_all(factors, 0, (int)factors.size());
    const int cycles = (int)c.size();
    mint<MOD> ans = 0;
    for (int i = 0; i <= n - cycles && i < (int)a.size(); i++) {
        ans += S[cycles + i] * a[i];
    }
    return (n & 1) ? -ans : ans;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    vector<int> p(n);
    UnionFind uf(n);
    for (int i = 0; i < n; i++) {
        cin >> p[i];
        uf.unite(i, p[i] - 1);  // p is read 1-indexed
    }
    // Component sizes of i -> p_i (the cycle lengths, assuming p is a permutation).
    vector<int> c;
    for (int i = 0; i < n; i++) {
        if (uf.root(i) == i) c.push_back(uf.size(i));
    }
    // S(x) = (e^x - 1)^m / m! truncated to degree n; then i! [x^i] S = S(i, m).
    S = FormalPowerSeries<MOD>(n + 1);
    S[1] = 1;
    S = S.exp();
    S[0] -= 1;
    mint<MOD> fact = 1;
    for (int i = 2; i <= m; i++) fact *= i;
    S = S.pow(m) / fact;
    fact = 1;
    for (int i = 1; i <= n; i++) {
        fact *= i;
        S[i] *= fact;
    }
    cout << solve1(c) << '\n';
    return 0;
}