結果
問題 | No.1392 Don't be together |
ユーザー |
![]() |
提出日時 | 2021-02-12 22:37:21 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 | AC |
実行時間 | 733 ms / 2,000 ms |
コード長 | 5,908 bytes |
コンパイル時間 | 2,253 ms |
コンパイル使用メモリ | 205,728 KB |
最終ジャッジ日時 | 2025-01-18 19:09:54 |
ジャッジサーバーID (参考情報) | judge5 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 27 |
ソースコード
#include <bits/stdc++.h>
#define FOR(i, n, m) for (ll i = (n); i < (ll)(m); i++)
#define REP(i, n) FOR(i, 0, n)
#define ALL(v) v.begin(), v.end()
#define pb push_back
using namespace std;
using ll = long long;
using ld = long double;
using P = pair<ll, ll>;
constexpr ll inf = 1000000000;
constexpr ll mod = 998244353;
constexpr long double eps = 1e-6;

// Prints a pair of numbers as "first second" (both go through to_string).
template <typename T1, typename T2>
ostream& operator<<(ostream& os, pair<T1, T2> p) {
    os << to_string(p.first) << " " << to_string(p.second);
    return os;
}

// Prints the elements separated by single spaces (no trailing newline).
// Takes a const reference so const vectors and temporaries can be printed.
template <typename T>
ostream& operator<<(ostream& os, const vector<T>& v) {
    REP(i, v.size()) {
        if (i) os << " ";
        os << v[i];
    }
    return os;
}

// Integer modulo `mod`; the constructor normalises the stored value n into [0, mod).
struct modint {
    ll n;
public:
    modint(const ll n = 0) : n((n % mod + mod) % mod) {}
    // a^m by binary exponentiation.
    static modint pow(modint a, int m) {
        modint r = 1;
        while (m > 0) {
            if (m & 1) { r *= a; }
            a = (a * a);
            m /= 2;
        }
        return r;
    }
    modint &operator++() { *this += 1; return *this; }
    modint &operator--() { *this -= 1; return *this; }
    modint operator++(int) { modint ret = *this; *this += 1; return ret; }
    modint operator--(int) { modint ret = *this; *this -= 1; return ret; }
    // Multiplicative inverse as n^(mod-2) (Fermat; relies on mod being prime).
    modint operator~() const { return (this->pow(n, mod - 2)); }
    friend bool operator==(const modint& lhs, const modint& rhs) { return lhs.n == rhs.n; }
    friend bool operator<(const modint& lhs, const modint& rhs) { return lhs.n < rhs.n; }
    friend bool operator>(const modint& lhs, const modint& rhs) { return lhs.n > rhs.n; }
    friend modint &operator+=(modint& lhs, const modint& rhs) {
        lhs.n += rhs.n;
        if (lhs.n >= mod) lhs.n -= mod;
        return lhs;
    }
    friend modint &operator-=(modint& lhs, const modint& rhs) {
        lhs.n -= rhs.n;
        if (lhs.n < 0) lhs.n += mod;
        return lhs;
    }
    friend modint &operator*=(modint& lhs, const modint& rhs) {
        lhs.n = (lhs.n * rhs.n) % mod;
        return lhs;
    }
    friend modint &operator/=(modint& lhs, const modint& rhs) {
        lhs.n = (lhs.n * (~rhs).n) % mod;
        return lhs;
    }
    friend modint operator+(const modint& lhs, const modint& rhs) { return modint(lhs.n + rhs.n); }
    friend modint operator-(const modint& lhs, const modint& rhs) { return modint(lhs.n - rhs.n); }
    friend modint operator*(const modint& lhs, const modint& rhs) { return modint(lhs.n * rhs.n); }
    friend modint operator/(const modint& lhs, const modint& rhs) { return modint(lhs.n * (~rhs).n); }
};

// Reads an integer into m and normalises it into [0, mod).
// m is taken by reference: the former by-value parameter only filled a copy,
// so `cin >> x` silently left x unchanged.
istream& operator>>(istream& is, modint& m) {
    ll v;
    is >> v;
    if (is) m = modint(v);
    return is;
}
ostream& operator<<(ostream& os, modint m) { os << m.n; return os; }

#define MAX_N 1010101

// Extended Euclid: returns gcd(a, b) and sets x, y with a*x + b*y == gcd(a, b).
long long extgcd(long long a, long long b, long long& x, long long& y) {
    long long d = a;
    if (b != 0) {
        d = extgcd(b, a % b, y, x);
        y -= (a / b) * x;
    } else {
        x = 1;
        y = 0;
    }
    return d;
}

// Inverse of a modulo m, or -1 when gcd(a, m) != 1.
long long mod_inverse(long long a, long long m) {
    long long x, y;
    if (extgcd(a, m, x, y) == 1) return (m + x % m) % m;
    else return -1;
}

// Factorial table, filled lazily by mod_fact; `inf` marks "not built yet".
vector<long long> fact(MAX_N + 1, inf);

// Returns n! with every factor of `mod` removed, modulo `mod`, and stores the
// exponent of `mod` in n! into e.
// NOTE: indexes fact[n % mod], so it is only valid for n <= MAX_N.
long long mod_fact(long long n, long long& e) {
    if (fact[0] == inf) {
        fact[0] = 1;
        if (MAX_N != 0) fact[1] = 1;
        for (ll i = 2; i <= MAX_N; ++i) {
            fact[i] = (fact[i - 1] * i) % mod;
        }
    }
    e = 0;
    if (n == 0) return 1;
    long long res = mod_fact(n / mod, e);
    e += n / mod;
    // (mod-1)! == -1 (mod mod), so an odd number of full blocks flips the sign.
    if ((n / mod) % 2 != 0) return (res * (mod - fact[n % mod])) % mod;
    return (res * fact[n % mod]) % mod;
}

// Binomial coefficient C(n, k) modulo `mod`; 0 when k is outside [0, n].
long long mod_comb(long long n, long long k) {
    if (n < 0 || k < 0 || n < k) return 0;
    long long e1, e2, e3;
    long long a1 = mod_fact(n, e1), a2 = mod_fact(k, e2), a3 = mod_fact(n - k, e3);
    if (e1 > e2 + e3) return 0;
    return (a1 * mod_inverse((a2 * a3) % mod, mod)) % mod;
}

using mi = modint;

// a^n for a modint base and non-negative exponent n.
mi mod_pow(mi a, ll n) {
    mi ret = 1;
    mi tmp = a;
    while (n > 0) {
        if (n % 2) ret *= tmp;
        tmp = tmp * tmp;
        n /= 2;
    }
    return ret;
}

// a^n mod m for plain integers. The base is reduced into [0, m) first, so a
// large or negative a can neither overflow tmp * tmp nor yield a negative result.
ll mod_pow(ll a, ll n, ll m) {
    ll ret = 1 % m;
    ll tmp = (a % m + m) % m;
    while (n > 0) {
        if (n % 2) ret = ret * tmp % m;
        tmp = tmp * tmp % m;
        n /= 2;
    }
    return ret;
}

// Euclid's gcd.
ll gcd(ll a, ll b) {
    if (b == 0) return a;
    return gcd(b, a % b);
}

// Number of ways to colour a cycle of length k with m colours so that adjacent
// vertices differ. For k >= 2 this is the chromatic polynomial of the cycle,
// (m-1)^k + (-1)^k (m-1); it equals m(m-1) for k = 2 and m(m-1)(m-2) for k = 3.
// k == 1 returns m (the element is treated as unconstrained), kept from the
// earlier memoised recursion. The closed form replaces a fixed 5001x5001 memo
// (~200 MB, out of bounds for m or k > 5000).
mi cycle_colorings(ll m, ll k) {
    if (k == 1) return m;
    mi base = m - 1;
    mi res = mod_pow(base, k);
    if (k % 2 == 0) res += base;
    else res -= base;
    return res;
}

// Reads n, m and a permutation p (1-based on input). Prints, modulo `mod`, the
// number of colourings of positions 0..n-1 that use exactly m colours, where
// i and p[i] always get different colours, divided by m! (so colour classes
// are unordered). Inclusion-exclusion over the number i of allowed colours:
//   sum_{i} (-1)^{m-i} C(m, i) * prod_{cycles c} cycle_colorings(i, |c|).
// NOTE(review): the i = 1 term is skipped. That is only exact when p has no
// fixed points (p_i != i), presumably guaranteed by the problem — confirm.
int main() {
    cin.tie(nullptr);
    ios::sync_with_stdio(false);
    int n, m;
    cin >> n >> m;
    vector<int> p(n);
    REP(i, n) {
        cin >> p[i];
        p[i] -= 1;
    }
    if (m == 1) {
        cout << 0 << '\n';
        return 0;
    }

    // Cycle decomposition of p, aggregated as cycle length -> number of cycles,
    // so each distinct length is evaluated once per colour count.
    vector<bool> used(n, false);
    map<ll, ll> cycle_count;
    REP(i, n) {
        if (used[i]) continue;
        ll len = 0;
        for (ll pos = i; !used[pos]; pos = p[pos]) {
            used[pos] = true;
            len += 1;
        }
        cycle_count[len] += 1;
    }

    mi ans = 0;
    for (ll i = 2; i <= m; ++i) {
        mi ways = 1;
        for (const auto& kv : cycle_count) {
            ways *= mod_pow(cycle_colorings(i, kv.first), kv.second);
        }
        mi term = mod_comb(m, i) * ways;
        if ((m - i) % 2) ans -= term;
        else ans += term;
    }
    mi fct = 1;
    FOR(i, 1, m + 1) fct *= i;
    cout << ans / fct << '\n';
    return 0;
}