結果
問題 | No.1302 Random Tree Score |
ユーザー | masayoshi361 |
提出日時 | 2020-11-30 08:42:12 |
言語 | C++14 (gcc 12.3.0 + boost 1.83.0) |
結果 | AC |
実行時間 | 437 ms / 3,000 ms |
コード長 | 16,703 bytes |
コンパイル時間 | 2,957 ms |
コンパイル使用メモリ | 202,264 KB |
実行使用メモリ | 11,268 KB |
最終ジャッジ日時 | 2024-09-13 02:14:32 |
合計ジャッジ時間 | 7,481 ms |
ジャッジサーバーID (参考情報) | judge4 / judge5 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 2 ms 6,816 KB |
testcase_01 | AC | 2 ms 6,944 KB |
testcase_02 | AC | 100 ms 6,940 KB |
testcase_03 | AC | 211 ms 7,668 KB |
testcase_04 | AC | 100 ms 6,940 KB |
testcase_05 | AC | 435 ms 11,268 KB |
testcase_06 | AC | 435 ms 11,216 KB |
testcase_07 | AC | 100 ms 6,940 KB |
testcase_08 | AC | 215 ms 8,152 KB |
testcase_09 | AC | 437 ms 11,260 KB |
testcase_10 | AC | 424 ms 10,392 KB |
testcase_11 | AC | 97 ms 6,944 KB |
testcase_12 | AC | 430 ms 10,728 KB |
testcase_13 | AC | 2 ms 6,944 KB |
testcase_14 | AC | 437 ms 11,136 KB |
testcase_15 | AC | 437 ms 11,264 KB |
testcase_16 | AC | 2 ms 6,944 KB |
ソースコード
#line 1 "verify/FPS.power.test.cpp"
#define PROBLEM "https://yukicoder.me/problems/no/1302"
#line 1 "library/template/template.cpp"
/* #region header */

#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
using namespace std;

// types
using ll = long long;
using ull = unsigned long long;
using ld = long double;
using Pl = pair<ll, ll>;
using Pi = pair<int, int>;
using vl = vector<ll>;
using vi = vector<int>;
using vc = vector<char>;
template <typename T>
using mat = vector<vector<T>>;
using vvi = vector<vector<int>>;
using vvl = vector<vector<long long>>;
using vvc = vector<vector<char>>;

// abbreviations
#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
// rep(i, n) iterates i over [0, n); rep(i, a, b) iterates i over [a, b).
#define rep_(i, a_, b_, a, b, ...) for (ll i = (a), max_i = (b); i < max_i; i++)
#define rep(i, ...) rep_(i, __VA_ARGS__, __VA_ARGS__, 0, __VA_ARGS__)
// rrep(i, n) / rrep(i, a, b): same ranges as rep, visited in decreasing order.
#define rrep_(i, a_, b_, a, b, ...) \
    for (ll i = (b - 1), min_i = (a); i >= min_i; i--)
#define rrep(i, ...) rrep_(i, __VA_ARGS__, __VA_ARGS__, 0, __VA_ARGS__)
// srep(i, a, b, c): i over [a, b) with step c.
#define srep(i, a, b, c) for (ll i = (a), max_i = (b); i < max_i; i += c)
#define SZ(x) ((int)(x).size())
#define pb(x) push_back(x)
#define eb(x) emplace_back(x)
#define mp make_pair

// input / output
#define print(x) cout << x << endl

/// Writes every element of `v` followed by a space, then a newline, to `os`.
template <class T>
ostream& operator<<(ostream& os, const vector<T>& v) {
    // Write to the stream that was passed in, not unconditionally to cout.
    for (const auto& e : v) os << e << " ";
    os << '\n';
    return os;
}

void scan(int& a) { cin >> a; }
void scan(long long& a) { cin >> a; }
void scan(char& a) { cin >> a; }
void scan(double& a) { cin >> a; }
void scan(string& a) { cin >> a; }
/// Reads every element of an already-sized vector from stdin.
template <class T>
void scan(vector<T>& a) {
    for (auto& i : a) scan(i);
}

#define vsum(x) accumulate(all(x), 0LL)
#define vmax(a) *max_element(all(a))
#define vmin(a) *min_element(all(a))
#define lb(c, x) distance((c).begin(), lower_bound(all(c), (x)))
#define ub(c, x) distance((c).begin(), upper_bound(all(c), (x)))

// functions

/// Euclidean algorithm. gcd(a, 0) == a and gcd(0, b) == b.
ll gcd(ll a, ll b) {
    while (b) {
        a %= b;
        swap(a, b);
    }
    return a;
}

/// Least common multiple; returns 0 when both arguments are 0
/// (instead of dividing by zero).
ll lcm(ll a, ll b) {
    const ll g = gcd(a, b);
    return g ? a / g * b : 0;
}

/// Sets a = b if a < b. Returns true when a was updated.
template <class T>
bool chmax(T& a, const T& b) {
    if (a < b) {
        a = b;
        return true;
    }
    return false;
}

/// Sets a = b if b < a. Returns true when a was updated.
template <class T>
bool chmin(T& a, const T& b) {
    if (b < a) {
        a = b;
        return true;
    }
    return false;
}

/// Binary exponentiation: x^n for n >= 0 (returns T(1) for n <= 0).
template <typename T>
T mypow(T x, ll n) {
    T ret = 1;
    while (n > 0) {
        if (n & 1) ret *= x;
        x *= x;
        n >>= 1;
    }
    return ret;
}

/// x^n modulo `mod`, result in [0, mod).
/// Requires mod > 0 and (mod - 1)^2 representable in long long.
ll modpow(ll x, ll n, const ll mod) {
    ll ret = 1 % mod;  // 0 when mod == 1
    // Reduce the base first so the first multiplication cannot overflow.
    x %= mod;
    if (x < 0) x += mod;
    while (n > 0) {
        if (n & 1) ret = ret * x % mod;
        x = x * x % mod;
        n >>= 1;
    }
    return ret;
}

/// Xorshift-style pseudo random generator with a fixed seed.
uint64_t my_rand(void) {
    static uint64_t x = 88172645463325252ULL;
    x = x ^ (x << 13);
    x = x ^ (x >> 7);
    return x = x ^ (x << 17);
}

int popcnt(ull x) { return __builtin_popcountll(x); }

/// Returns the indices 0..n-1 sorted by the value a[index] (ascending).
template <typename T>
vector<int> IOTA(const vector<T>& a) {
    const int n = (int)a.size();
    vector<int> id(n);
    iota(all(id), 0);
    sort(all(id), [&](int i, int j) { return a[i] < a[j]; });
    return id;
}

/// Stopwatch based on clock().
struct Timer {
    clock_t start_time;
    void start() { start_time = clock(); }
    /// Time elapsed since start(), in milliseconds.
    int lap() { return (clock() - start_time) * 1000 / CLOCKS_PER_SEC; }
};

/* #endregion*/

// constant
#define inf 1000000000ll
#define INF 4000000004000000000LL
#define endl '\n'
const long double eps = 0.000000000000001;
const long double PI = 3.141592653589793;
#line 3 "verify/FPS.power.test.cpp"

// library
#line 1 "library/convolution/NTT.cpp"
/// Number-theoretic transform over Mint.
/// Mint must provide get_mod(), pow(), and the usual arithmetic operators.
template <typename Mint>
struct NTT {
    vector<Mint> dw, idw;  // dw[i] = -root^((Mod-1) >> (i+2)); idw[i] = 1 / dw[i]
    int max_base;          // exponent of 2 in Mod - 1
    Mint root;             // smallest value >= 2 with root^((Mod-1)/2) != 1

    NTT() {
        const unsigned Mod = Mint::get_mod();
        assert(Mod >= 3 && Mod % 2 == 1);
        auto tmp = Mod - 1;
        max_base = 0;
        while (tmp % 2 == 0) tmp >>= 1, max_base++;
        root = 2;
        while (root.pow((Mod - 1) >> 1) == 1) root += 1;
        assert(root.pow(Mod - 1) == 1);
        dw.resize(max_base);
        idw.resize(max_base);
        for (int i = 0; i < max_base; i++) {
            dw[i] = -root.pow((Mod - 1) >> (i + 2));
            idw[i] = Mint(1) / dw[i];
        }
    }

    /// In-place forward transform. a.size() must be a power of two
    /// not exceeding 2^max_base.
    void ntt(vector<Mint>& a) {
        const int n = (int)a.size();
        assert((n & (n - 1)) == 0);
        assert(__builtin_ctz(n) <= max_base);
        for (int m = n; m >>= 1;) {
            Mint w = 1;
            for (int s = 0, k = 0; s < n; s += 2 * m) {
                for (int i = s, j = s + m; i < s + m; ++i, ++j) {
                    auto x = a[i], y = a[j] * w;
                    a[i] = x + y, a[j] = x - y;
                }
                w *= dw[__builtin_ctz(++k)];
            }
        }
    }

    /// In-place inverse of ntt(). When `f` is true the result is also
    /// multiplied by 1/n; pass false if the caller already applied that factor.
    void intt(vector<Mint>& a, bool f = true) {
        const int n = (int)a.size();
        assert((n & (n - 1)) == 0);
        assert(__builtin_ctz(n) <= max_base);
        for (int m = 1; m < n; m *= 2) {
            Mint w = 1;
            for (int s = 0, k = 0; s < n; s += 2 * m) {
                for (int i = s, j = s + m; i < s + m; ++i, ++j) {
                    auto x = a[i], y = a[j];
                    a[i] = x + y, a[j] = (x - y) * w;
                }
                w *= idw[__builtin_ctz(++k)];
            }
        }
        if (f) {
            Mint inv_sz = Mint(1) / n;
            for (int i = 0; i < n; i++) a[i] *= inv_sz;
        }
    }

    /// Convolution of a and b; the result has a.size() + b.size() - 1 terms
    /// (empty when either operand is empty).
    vector<Mint> multiply(vector<Mint> a, vector<Mint> b) {
        // a.size() + b.size() - 1 would underflow for two empty operands.
        if (a.empty() || b.empty()) return {};
        const int need = (int)(a.size() + b.size()) - 1;
        int nbase = 1;
        while ((1 << nbase) < need) nbase++;
        const int sz = 1 << nbase;
        a.resize(sz, 0);
        b.resize(sz, 0);
        ntt(a);
        ntt(b);
        // Fold the 1/sz normalisation into the pointwise product so that
        // intt() can skip its own scaling pass.
        Mint inv_sz = Mint(1) / sz;
        for (int i = 0; i < sz; i++) a[i] *= b[i] * inv_sz;
        intt(a, false);
        a.resize(need);
        return a;
    }
};
#line 1 "library/math/FormalPowerSeries.cpp"
/// Truncated formal power series / polynomial with coefficients of type T.
/// Element i is the coefficient of x^i. Multiplication is delegated to the
/// function registered with set_fft(), which must be called before the first
/// product of two non-empty series.
template <typename T>
struct FormalPowerSeries : vector<T> {
    using vector<T>::vector;
    using P = FormalPowerSeries;
    using MULT = function<P(P, P)>;

    static MULT& get_mult() {
        static MULT mult = nullptr;
        return mult;
    }
    static void set_fft(MULT f) { get_mult() = f; }

    /// Removes trailing zero coefficients.
    void shrink() {
        while (this->size() && this->back() == T(0)) this->pop_back();
    }

    P operator+(const P& r) const { return P(*this) += r; }
    P operator+(const T& v) const { return P(*this) += v; }
    P operator-(const P& r) const { return P(*this) -= r; }
    P operator-(const T& v) const { return P(*this) -= v; }
    P operator*(const P& r) const { return P(*this) *= r; }
    P operator*(const T& v) const { return P(*this) *= v; }
    P operator/(const P& r) const { return P(*this) /= r; }
    P operator%(const P& r) const { return P(*this) %= r; }

    P& operator+=(const P& r) {
        if (r.size() > this->size()) this->resize(r.size());
        for (size_t i = 0; i < r.size(); i++) (*this)[i] += r[i];
        return *this;
    }
    P& operator+=(const T& r) {
        if (this->empty()) this->resize(1);
        (*this)[0] += r;
        return *this;
    }
    /// Note: unlike +=, subtraction drops trailing zeros afterwards.
    P& operator-=(const P& r) {
        if (r.size() > this->size()) this->resize(r.size());
        for (size_t i = 0; i < r.size(); i++) (*this)[i] -= r[i];
        shrink();
        return *this;
    }
    P& operator-=(const T& r) {
        if (this->empty()) this->resize(1);
        (*this)[0] -= r;
        shrink();
        return *this;
    }
    P& operator*=(const T& v) {
        const int n = (int)this->size();
        for (int k = 0; k < n; k++) (*this)[k] *= v;
        return *this;
    }
    P& operator*=(const P& r) {
        if (this->empty() || r.empty()) {
            this->clear();
            return *this;
        }
        assert(get_mult() != nullptr);
        return *this = get_mult()(*this, r);
    }
    P& operator%=(const P& r) { return *this -= *this / r * r; }

    P operator-() const {
        P ret(this->size());
        for (size_t i = 0; i < this->size(); i++) ret[i] = -(*this)[i];
        return ret;
    }

    /// Polynomial quotient, computed through the reversed polynomials.
    P& operator/=(const P& r) {
        if (this->size() < r.size()) {
            this->clear();
            return *this;
        }
        int n = this->size() - r.size() + 1;
        return *this = (rev().pre(n) * r.rev().inv(n)).pre(n).rev(n);
    }

    /// First min(size, sz) coefficients.
    P pre(int sz) const {
        return P(begin(*this), begin(*this) + min((int)this->size(), sz));
    }

    /// f / x^sz (drops the lowest sz coefficients).
    P operator>>(int sz) const {
        if ((int)this->size() <= sz) return {};
        P ret(*this);
        ret.erase(ret.begin(), ret.begin() + sz);
        return ret;
    }

    /// f * x^sz (prepends sz zero coefficients).
    P operator<<(int sz) const {
        P ret(*this);
        ret.insert(ret.begin(), sz, T(0));
        return ret;
    }

    /// Coefficients in reverse order; when deg != -1 the series is first
    /// resized to deg terms.
    P rev(int deg = -1) const {
        P ret(*this);
        if (deg != -1) ret.resize(deg, T(0));
        reverse(begin(ret), end(ret));
        return ret;
    }

    /// Formal derivative.
    P diff() const {
        const int n = (int)this->size();
        P ret(max(0, n - 1));
        for (int i = 1; i < n; i++) ret[i - 1] = (*this)[i] * T(i);
        return ret;
    }

    /// Formal integral with constant term 0.
    P integral() const {
        const int n = (int)this->size();
        P ret(n + 1);
        ret[0] = T(0);
        for (int i = 0; i < n; i++) ret[i + 1] = (*this)[i] / T(i + 1);
        return ret;
    }

    /// First deg terms of 1 / f (at most deg terms are returned).
    /// F(0) must not be 0.
    P inv(int deg = -1) const {
        assert(!this->empty());
        assert(((*this)[0]) != T(0));
        const int n = (int)this->size();
        if (deg == -1) deg = n;
        P ret({T(1) / (*this)[0]});
        // Newton iteration: each round doubles the number of correct terms.
        for (int i = 1; i < deg; i <<= 1) {
            ret = (ret + ret - ret * ret * pre(i << 1)).pre(i << 1);
        }
        return ret.pre(deg);
    }

    /// log f as the integral of f' / f truncated to deg - 1 terms.
    /// F(0) must be 1; deg (after defaulting) must be >= 1.
    P log(int deg = -1) const {
        assert(!this->empty());
        assert((*this)[0] == 1);
        const int n = (int)this->size();
        if (deg == -1) deg = n;
        return (this->diff() * this->inv(deg)).pre(deg - 1).integral();
    }

    /// Square root. Handles a constant term of 0 (by factoring out an even
    /// power of x) and otherwise starts the iteration from 1, i.e. it treats
    /// the constant term as 1. Returns an empty series when the lowest
    /// non-zero term has odd degree.
    P sqrt(int deg = -1) const {
        const int n = (int)this->size();
        if (deg == -1) deg = n;
        if (n == 0) return P(deg, T(0));
        if ((*this)[0] == T(0)) {
            for (int i = 1; i < n; i++) {
                if ((*this)[i] != T(0)) {
                    if (i & 1) return {};
                    if (deg - i / 2 <= 0) break;
                    auto ret = (*this >> i).sqrt(deg - i / 2) << (i / 2);
                    if ((int)ret.size() < deg) ret.resize(deg, T(0));
                    return ret;
                }
            }
            return P(deg, T(0));
        }
        P ret({T(1)});
        T inv2 = T(1) / T(2);
        for (int i = 1; i < deg; i <<= 1) {
            ret = (ret + pre(i << 1) * ret.inv(i << 1)) * inv2;
        }
        return ret.pre(deg);
    }

    /// First deg terms of exp(f) (at most deg terms are returned).
    /// F(0) must be 0.
    P exp(int deg = -1) const {
        assert(!this->empty());
        assert((*this)[0] == T(0));
        const int n = (int)this->size();
        if (deg == -1) deg = n;
        P ret({T(1)});
        for (int i = 1; i < deg; i <<= 1) {
            ret = (ret * (pre(i << 1) + T(1) - ret.log(i << 1))).pre(i << 1);
        }
        return ret.pre(deg);
    }

    /// First deg terms of f^k, always returned with exactly deg coefficients
    /// (empty when deg <= 0).
    ///
    /// Writes f = c * x^i * g with g(0) = 1 and uses
    /// f^k = c^k * x^(i*k) * exp(k * log g).
    /// A negative k is accepted only when the constant term is non-zero.
    /// T must provide a pow(int64_t) member.
    P pow(int64_t k, int deg = -1) const {
        const int n = (int)this->size();
        if (deg == -1) deg = n;
        if (deg <= 0) return P();
        if (k == 0) {
            // f^0 == 1 for every f, including the zero series.
            P one(deg);
            one[0] = T(1);
            return one;
        }
        for (int i = 0; i < n; i++) {
            if ((*this)[i] == T(0)) continue;
            if (i > 0) {
                assert(k >= 0);
                // i * k >= deg: every requested coefficient is zero. Comparing
                // against ceil(deg / i) avoids overflowing i * k.
                if (k >= (deg + i - 1) / i) return P(deg);
            }
            const int shift = (int)(i * k);  // < deg by the check above
            const int m = deg - shift;       // number of terms of g^k needed
            const T c = (*this)[i];
            const T inv_c = T(1) / c;
            // g = f / (c * x^i), padded or truncated to exactly m terms so the
            // result is also correct when deg exceeds the input length.
            P D(m);
            for (int j = 0; j < m && i + j < n; j++) {
                D[j] = (*this)[i + j] * inv_c;
            }
            const T scale = k >= 0 ? c.pow(k) : inv_c.pow(-k);
            D = (D.log(m) * T(k)).exp(m) * scale;
            P E(deg);
            for (int j = 0; j < m && j < (int)D.size(); j++) {
                E[j + shift] = D[j];
            }
            return E;
        }
        // Zero series raised to a positive power.
        return P(deg);
    }

    /// Evaluates the polynomial at x.
    T eval(T x) const {
        T r = 0, w = 1;
        for (auto& v : *this) {
            r += w * v;
            w *= x;
        }
        return r;
    }
};

// NTT<mint> ntt;
// FPS mult_ntt(const FPS::P& a, const FPS::P& b) {
//     auto ret = ntt.multiply(a, b);
//     return FPS::P(ret.begin(), ret.end());
// }
// FPS mult(const FPS::P& a, const FPS::P& b) {
//     FPS c(a.size() + b.size() - 1);
//     rep(i, a.size()) rep(j, b.size()) { c[i + j] += a[i] * b[j]; }
//     return c;
// }
#line 1 "library/math/combination.cpp"
/**
 * @brief Combination(P, C, H, Stirling number, Bell number)
 * @docs docs/Combination.md
 */
template <typename T>
struct Combination {
    // _fact[i] = i!, _rfact[i] = 1 / i!, _inv[i] = 1 / i (for i >= 1)
    vector<T> _fact, _rfact, _inv;

    /// Precomputes the tables for arguments 0..sz.
    Combination(int sz) : _fact(sz + 1), _rfact(sz + 1), _inv(sz + 1) {
        _fact[0] = _rfact[sz] = _inv[0] = 1;
        for (int i = 1; i <= sz; i++) _fact[i] = _fact[i - 1] * i;
        _rfact[sz] /= _fact[sz];
        for (int i = sz - 1; i >= 0; i--) _rfact[i] = _rfact[i + 1] * (i + 1);
        for (int i = 1; i <= sz; i++) _inv[i] = _rfact[i] * _fact[i - 1];
    }

    inline T fact(int k) const { return _fact[k]; }
    inline T rfact(int k) const { return _rfact[k]; }
    inline T inv(int k) const { return _inv[k]; }

    /// n! / (n - r)!; 0 when r < 0 or n < r.
    T P(int n, int r) const {
        if (r < 0 || n < r) return 0;
        return fact(n) * rfact(n - r);
    }

    /// Binomial coefficient; 0 when q < 0 or p < q.
    T C(int p, int q) const {
        if (q < 0 || p < q) return 0;
        return fact(p) * rfact(q) * rfact(p - q);
    }

    /// C(n + r - 1, r), with H(n, 0) == 1; 0 for negative arguments.
    T H(int n, int r) const {
        if (n < 0 || r < 0) return (0);
        return r == 0 ? T(1) : C(n + r - 1, r);
    }

    // O(k log(n))
    // Number of ways to split n distinguishable balls into k groups
    // (every group has size at least 1).
    T Stirling(int n, int k) {
        T res = 0;
        for (int i = 0; i <= k; i++) {
            res += (T)((k - i) % 2 ? -1 : 1) * C(k, i) * mypow<T>(T(i), n);
        }
        return res / _fact[k];
    }

    // O(k log(n))
    // Number of ways to split n distinguishable balls into k groups
    // (group sizes may be 0); equivalently, into at most k groups that each
    // contain at least one ball.
    T Bell(int n, int k) {
        if (n < k) k = n;
        vector<T> sm(k + 1);
        sm[0] = 1;
        for (int j = 1; j <= k; j++) {
            sm[j] = sm[j - 1] + (T)(j % 2 ? -1 : 1) / _fact[j];
        }
        T res = 0;
        for (int i = 0; i <= k; i++) {
            res += mypow<T>(T(i), n) / _fact[i] * sm[k - i];
        }
        return res;
    }
};
#line 1 "library/mod/modint.cpp"
/// Integer modulo the compile-time constant `mod`; x is kept in [0, mod).
template <int mod>
struct modint {
    int x;

    modint() : x(0) {}
    modint(long long y) {
        // Reduce before fixing the sign; negating y first would overflow
        // for the most negative long long.
        y %= mod;
        if (y < 0) y += mod;
        x = (int)y;
    }

    modint& operator+=(const modint& p) {
        if ((x += p.x) >= mod) x -= mod;
        return *this;
    }
    modint& operator-=(const modint& p) {
        if ((x += mod - p.x) >= mod) x -= mod;
        return *this;
    }
    modint& operator*=(const modint& p) {
        x = (int)(1LL * x * p.x % mod);
        return *this;
    }
    modint& operator/=(const modint& p) {
        *this *= p.inverse();
        return *this;
    }

    modint operator-() const { return modint(-x); }
    modint operator+(const modint& p) const { return modint(*this) += p; }
    modint operator-(const modint& p) const { return modint(*this) -= p; }
    modint operator*(const modint& p) const { return modint(*this) *= p; }
    modint operator/(const modint& p) const { return modint(*this) /= p; }
    bool operator==(const modint& p) const { return x == p.x; }
    bool operator!=(const modint& p) const { return x != p.x; }

    /// Multiplicative inverse via the extended Euclidean algorithm.
    modint inverse() const {
        int a = x, b = mod, u = 1, v = 0, t;
        while (b > 0) {
            t = a / b;
            swap(a -= t * b, b);
            swap(u -= t * v, v);
        }
        return modint(u);
    }

    /// Binary exponentiation; returns 1 for n <= 0.
    modint pow(int64_t n) const {
        modint ret(1), mul(x);
        while (n > 0) {
            if (n & 1) ret *= mul;
            mul *= mul;
            n >>= 1;
        }
        return ret;
    }

    friend ostream& operator<<(ostream& os, const modint& p) {
        return os << p.x;
    }
    friend istream& operator>>(istream& is, modint& a) {
        long long t;
        is >> t;
        a = modint<mod>(t);
        return (is);
    }

    static int get_mod() { return mod; }
    inline int get() const { return x; }
};
#line 8 "verify/FPS.power.test.cpp"
using mint = modint<998244353>;
using FPS = FormalPowerSeries<mint>;
NTT<mint> ntt;

/// Multiplication backend for FPS, implemented with the global NTT instance.
FPS mult_ntt(const FPS::P& a, const FPS::P& b) {
    auto ret = ntt.multiply(a, b);
    return FPS::P(ret.begin(), ret.end());
}
// FPS::set_fft(mult_ntt); in main

/// Reads n, builds f(x) = sum_{i=0..n} (i + 1) / i! * x^i, raises it to the
/// n-th power as exp(n * log f) truncated to n terms, and prints
/// [x^(n-2)] f^n * (n - 2)! / n^(n-2).
/// NOTE(review): indexes f[n - 2], so this assumes n >= 2 — presumably
/// guaranteed by the problem constraints; confirm.
int main() {
    FPS::set_fft(mult_ntt);
    int n;
    cin >> n;
    Combination<mint> comb(n);
    FPS f(n + 1);
    for (int i = 0; i <= n; i++) {
        f[i] = mint(i + 1) / comb.fact(i);
    }
    f = f.log(n) * mint(n);
    f = f.exp(n);
    const mint ans = f[n - 2] * comb.fact(n - 2) / mypow<mint>(n, n - 2);
    cout << ans << '\n';
}