結果

問題 No.1392 Don't be together
ユーザー 🍮かんプリン🍮かんプリン
提出日時 2021-02-21 15:05:15
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
WA  
実行時間 -
コード長 12,674 bytes
コンパイル時間 2,829 ms
コンパイル使用メモリ 203,132 KB
実行使用メモリ 23,060 KB
最終ジャッジ日時 2024-09-19 13:18:14
合計ジャッジ時間 6,506 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
10,752 KB
testcase_01 AC 2 ms
5,376 KB
testcase_02 AC 2 ms
5,376 KB
testcase_03 WA -
testcase_04 WA -
testcase_05 AC 2 ms
5,376 KB
testcase_06 TLE -
testcase_07 -- -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
testcase_29 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

/**
 *   @FileName	a.cpp
 *   @Author	kanpurin
 *   @Created	2021.02.21 15:05:10
**/

#include "bits/stdc++.h" 
using namespace std; 
// Shorthand for 64-bit arithmetic.
using ll = long long;

/**
 * Residue class modulo MOD, stored as its representative x in [0, MOD).
 *
 * Division uses Fermat's little theorem (a^(MOD-2)), so it is only meaningful when
 * MOD is prime and the divisor is nonzero; dividing by zero yields 0.
 */
template< int MOD >
struct mint {
public:
    long long x;
    // Maps any (possibly negative) integer into [0, MOD).
    mint(long long x = 0) :x((x%MOD+MOD)%MOD) {}
    // Parses a decimal string with an optional leading '-', reducing after every digit
    // so arbitrarily long inputs never overflow. Taking const& lets temporaries and
    // const strings be parsed as well.
    mint(const std::string &s) {
        const bool negative = !s.empty() && s[0] == '-';
        long long z = 0;
        for (std::size_t i = negative ? 1 : 0; i < s.size(); i++) {
            z = (z * 10 + (s[i] - '0')) % MOD;
        }
        this->x = negative ? (MOD - z) % MOD : z;
    }
    mint& operator+=(const mint &a) {
        if ((x += a.x) >= MOD) x -= MOD;
        return *this;
    }
    mint& operator-=(const mint &a) {
        if ((x += MOD - a.x) >= MOD) x -= MOD;
        return *this;
    }
    mint& operator*=(const mint &a) {
        (x *= a.x) %= MOD;
        return *this;
    }
    mint& operator/=(const mint &a) {
        // a^(MOD-2) is the inverse of a for prime MOD and a != 0.
        return *this *= a.pow(MOD - 2);
    }
    mint operator+(const mint &a) const {
        mint res(*this);
        return res += a;
    }
    mint operator-() const {return mint() -= *this; }
    mint operator-(const mint &a) const {
        mint res(*this);
        return res -= a;
    }
    mint operator*(const mint &a) const {
        mint res(*this);
        return res *= a;
    }
    mint operator/(const mint &a) const {
        mint res(*this);
        return res /= a;
    }
    friend std::ostream& operator<<(std::ostream &os, const mint &n) {
        return os << n.x;
    }
    friend std::istream &operator>>(std::istream &is, mint &n) {
        long long x;
        is >> x;
        n = mint(x);
        return is;
    }
    bool operator==(const mint &a) const {
        return this->x == a.x;
    }
    bool operator!=(const mint &a) const {
        return this->x != a.x;
    }
    // this^k by binary exponentiation; k is expected to be >= 0 (k <= 0 returns 1).
    mint pow(long long k) const {
        mint ret = 1;
        mint p = this->x;
        while (k > 0) {
            if (k & 1) {
                ret *= p;
            }
            p *= p;
            k >>= 1;
        }
        return ret;
    }
};




template < const int MOD , bool any = false>
struct FormalPowerSeries {
private:
    using P = FormalPowerSeries< MOD, any >;
    template < class T, class F = multiplies< T > >
    T power(T a, long long n, F op = multiplies< T >(), T e = {1}) const {
        assert(n >= 0);
        T res = e;
        while (n) {
            if (n & 1) res = op(res, a);
            if (n >>= 1) a = op(a, a);
        }
        return res;
    }
    template< int _MOD >
    void ntt(vector< mint < _MOD > >& a, bool inverse) {
        static vector< mint< _MOD > > dw(30), idw(30);
        if (dw[0] == 0) {
            mint< _MOD > root = 2;
            while (power(root, (_MOD - 1) / 2) == 1) root += 1;
            for (int i = 0; i < 30; i++) dw[i] = -power(root, (_MOD - 1) >> (i + 2)), idw[i] = mint<_MOD>(1) / dw[i];
        }
        int n = a.size();
        assert((n & (n - 1)) == 0);
        if (not inverse) {
            for (int m = n; m >>= 1;) {
                mint< _MOD > w = 1;
                for (int s = 0, k = 0; s < n; s += 2 * m) {
                    for (int i = s, j = s + m; i < s + m; i++, j++) {
                        auto x = a[i], y = a[j] * w;
                        if (x.x >= _MOD) x.x -= _MOD;
                        a[i].x = x.x + y.x, a[j].x = x.x + (_MOD - y.x);
                    }
                    w *= dw[__builtin_ctz(++k)];
                }
            }
        } else {
            for (int m = 1; m < n; m *= 2) {
                mint< _MOD > w = 1;
                for (int s = 0, k = 0; s < n; s += 2 * m) {
                    for (int i = s, j = s + m; i < s + m; i++, j++) {
                        auto x = a[i], y = a[j];
                        a[i] = x + y, a[j].x = x.x + (_MOD - y.x), a[j] *= w;
                    }
                    w *= idw[__builtin_ctz(++k)];
                }
            }
        }
        auto c = mint<_MOD>(1) / mint< _MOD >(inverse ? n : 1);
        for (auto&& e : a) e *= c;
    }
    template< int _MOD >
    vector< mint< _MOD > > convolution(vector< mint< _MOD > > l, vector< mint< _MOD > > r) {
        if (l.empty() || r.empty()) return {};
        int n = l.size(), m = r.size(), sz = 1 << __lg(2 * (n + m - 1) - 1);
        if (min(n, m) < 30) {
            vector< long long > res(n + m - 1);
            for (int i = 0; i < n; i++)
                for (int j = 0; j < m; j++) res[i + j] += (l[i] * r[j]).x;
            return {begin(res), end(res)};
        }
        bool eq = l == r;
        l.resize(sz), ntt(l, false);
        if (eq) r = l;
        else r.resize(sz), ntt(r, false);
        for (int i = 0; i < sz; i++) l[i] *= r[i];
        ntt(l, true), l.resize(n + m - 1);
        return l;
    }
    P pre(const P &p, int sz) const {
        P ret;
        ret.a = vector<mint<MOD>>(p.a.begin(), p.a.begin() + min((int)p.a.size(), sz));
        return ret;
    }
public:
    vector<mint<MOD>> a;
    FormalPowerSeries(int sz = 0) {
        this->a.resize(sz, 0);
    }
    
    P resize(int k) const {
        return pre(*this,k);
    }
    FormalPowerSeries(std::initializer_list<mint<MOD>> v) {
        this->a = v;
    }
    
    size_t size() const { return this->a.size(); }
    bool operator<(const P& r) const { return this->a.size() < r.a.size(); }
    bool operator>(const P& r) const { return this->a.size() > r.a.size(); }
    P operator+(const P &a) const { return P(*this) += a; }
    P operator+(const long long a) const { return P(*this) += a; }
    P operator-(const P &a) const { return P(*this) -= a; }
    P operator*(const P &a) const { return P(*this) *= a; }
    P operator*(const long long a) const { return P(*this) *= a; }
    P operator/(const P &a) const { return P(*this) /= a; }
    P operator/(const mint<MOD> &a) const { return P(*this) /= a; }
    P &operator+=(const P &r) {
        this->a.resize(max(this->a.size(),r.size()));
        for(int i = 0; i < (int)r.size(); i++) this->a[i] += r.a[i];
        return *this;
    }
    
    P &operator+=(const long long v) {
        if (this->a.size() == 0) this->a.resize(1,(v % MOD + MOD) % MOD);
        else this->a[0] += v;
        return *this;
    }
    P &operator-=(const P &r) {
        this->a.resize(max(this->a.size(),r.size()));
        for(int i = 0; i < (int)r.size(); i++) this->a[i] -= r.a[i];
        return *this;
    }
    P &operator*=(const P &b) {
        if (!any) {
            this->a = convolution(this->a, b.a);
            return *this;
        }
        else {
            if (this->a.empty() || b.a.empty()) {
                this->a.clear();
                return *this;
            }
            int n = this->a.size(), m = b.a.size();
            static constexpr int mod0 = 998244353, mod1 = 1300234241, mod2 = 1484783617;
            using Mint0 = mint< mod0 >;
            using Mint1 = mint< mod1 >;
            using Mint2 = mint< mod2 >;
            vector< Mint0 > l0(n), r0(m);
            vector< Mint1 > l1(n), r1(m);
            vector< Mint2 > l2(n), r2(m);
            for (int i = 0; i < n; i++) l0[i] = this->a[i].x, l1[i] = this->a[i].x, l2[i] = this->a[i].x;
            for (int j = 0; j < m; j++) r0[j] = b.a[j].x, r1[j] = b.a[j].x, r2[j] = b.a[j].x;
            l0 = convolution(l0,r0);
            l1 = convolution(l1,r1);
            l2 = convolution(l2,r2);
            this->a.resize(n + m - 1);
            static const Mint1 im0 = Mint1(1) / Mint1(mod0);
            static const Mint2 im1 = Mint2(1) / Mint2(mod1), im0m1 = im1 / mod0;
            static const mint<MOD> m0 = mod0, m0m1 = m0 * mod1;
            for (int i = 0; i < n + m - 1; i++) {
                int y0 = l0[i].x;
                int y1 = (im0 * (l1[i] - y0)).x;
                int y2 = (im0m1 * (l2[i] - y0) - im1 * y1).x;
                this->a[i] = m0m1 * y2 + y0 + m0 * y1;
            }
            return *this;
        }
    }
    P &operator*=(const long long v) {
        for (int i = 0; i < this->a.size(); i++) this->a[i] *= v;
        return *this;
    }
    
    P &operator/=(const P &a) {
        *this *= a.inverse();
        return *this;
    }
    
    P &operator/=(const mint<MOD> &v) {
        for (int i = 0; i < this->size(); i++) {
            this->a[i] /= v;
        }
        return *this;
    }
    
    P inverse(int deg = -1) const {
        assert(this->a.size() != 0 && this->a[0].x != 0);
        const int n = (int)this->a.size();
        if(deg == -1) deg = n;
        P ret(1);
        ret[0] = mint<MOD>(1) / a[0];
        for(int i = 1; i < deg; i <<= 1) {
            ret = pre((ret + ret - ret * ret * pre(*this,i << 1)),i << 1);
        }
        return pre(ret,deg);
    }
    
    P differential() const {
        const int n = (int) this->a.size();
        P ret(max(0, n - 1));
        for(int i = 1; i < n; i++) ret[i-1] = this->a[i] * i;
        return ret;
    }
    
    P integral() const {
        const int n = (int) this->a.size();
        P ret(n + 1);
        for(int i = 0; i < n; i++) ret[i + 1] = this->a[i] / (i + 1);
        return ret;
    }
    
    P log(int deg = -1) const {
        assert(this->a.size() != 0 && this->a[0] == 1);
        const int n = (int)this->a.size();
        if(deg == -1) deg = n;
        return pre((this->differential() * this->inverse(deg)),deg - 1).integral();
    }
    
    P exp(int deg = -1) const {
        if (this->a.size() == 0) return P(0);
        assert(this->a[0] == 0);
        const int n = (int)this->a.size();
        if(deg == -1) deg = n;
        P ret(1);
        ret.a[0] = 1;
        for(int i = 1; i < deg; i <<= 1) {
            ret = pre((ret * (pre(*this,i << 1) + 1 - ret.log(i << 1))),i << 1);
        }
        return pre(ret,deg);
    }
    
    P pow(long long k, int deg = -1) const {
        const int n = (int) this->a.size();
        if(deg == -1) deg = n;
        for(int i = 0; i < n; i++) {
            if(this->a[i].x != 0) {
                long long rev = (mint<MOD>(1) / this->a[i]).x;
                P C = *this * rev;
                P D(n - i);
                for(int j = i; j < n; j++) D[j - i] = C[j];
                D = (D.log() * k).exp() * power(this->a[i], k).x;
                P E(deg);
                if(i * k > deg) return E;
                auto S = i * k;
                for(int j = 0; j + S < deg && j < D.size(); j++) E[j + S] = D[j];
                return E;
            }
        }
        return *this;
    }
    mint< MOD > &operator[](int x) {
        assert(0 <= x && x < (int)this->a.size());
        return a[x];
    }
    friend std::ostream &operator<<(std::ostream &os, const P &p) {
        os << "[ ";
        for (int i = 0; i < p.size(); ++i) {
            os << p.a[i] << " ";
        }
        os << "]";
        return os;
    }
};
// Prime modulus for all arithmetic; 998244353 = 119 * 2^23 + 1, which makes it NTT-friendly.
constexpr int MOD = 998'244'353;

/**
 * Disjoint-set forest over the elements 0..n-1 with union by size and path compression.
 */
class UnionFind {
private:
    // parent_[v] >= 0 is v's parent; parent_[v] < 0 marks v as a root whose set
    // contains -parent_[v] elements.
    vector<int> parent_;
public:
    UnionFind(int n) {
        parent_.assign(n, -1);
    }
    // Representative of x's set; every node on the walked path is re-pointed at it.
    int root(int x) {
        int top = x;
        while (parent_[top] >= 0) top = parent_[top];
        while (x != top) {
            const int next = parent_[x];
            parent_[x] = top;
            x = next;
        }
        return top;
    }
    // Merges the sets containing x and y; returns false if they were already merged.
    bool unite(int x, int y) {
        int big = root(x);
        int small = root(y);
        if (big == small) return false;
        // A more negative entry means a larger set; ties keep x's root on top.
        if (parent_[big] > parent_[small]) std::swap(big, small);
        parent_[big] += parent_[small];
        parent_[small] = big;
        return true;
    }
    // True when x and y belong to the same set.
    bool same(int x, int y) {
        const int rx = root(x);
        return rx == root(y);
    }
    // Number of elements in x's set.
    int size(int x) {
        return -parent_[root(x)];
    }
};
int main() {
    int n,m;cin >> n >> m;
    vector<int> p(n);
    UnionFind uf(n);
    for (int i = 0; i < n; i++) {
        cin >> p[i];
        uf.unite(i,p[i]-1);
    }
    vector<int> c;
    for (int i = 0; i < n; i++) {
        if (uf.root(i) == i) {
            c.push_back(uf.size(i));
        }
    }
    FormalPowerSeries<MOD> S(n+1);
    S[1] = 1;
    S = S.exp();
    S[0] -= 1;
    mint<MOD> fact = 1;
    for (int i = 2; i <= m; i++) {
        fact *= i;
    }
    S = S.pow(m) / fact;
    fact = 1;
    for (int i = 1; i <= n; i++) {
        fact *= i;
        S[i] *= fact;
    }
    mint<MOD> ans = 0;
    for (int i = m; i <= n; i++) {
        FormalPowerSeries<MOD> a(i);
        a[0] = 1;
        for (int j = 0; j < c.size(); j++) {
            FormalPowerSeries<MOD> b(i);
            b[1] = 1;
            b[0] = 1;
            b = b.pow(c[j]);
            b[1] -= c[j]-1;
            b[0] -= 1;
            a *= b;
        }
        if ((n+i) & 1) {
            ans -= a[i] * S[i];
        }
        else {
            ans += a[i] * S[i];
        }
    }
    cout << ans << endl;
    return 0;
}
0