結果
問題 |
No.3160 Party Game
|
ユーザー |
![]() |
提出日時 | 2025-05-23 20:33:38 |
言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
結果 |
RE
|
実行時間 | - |
コード長 | 8,604 bytes |
コンパイル時間 | 7,207 ms |
コンパイル使用メモリ | 333,828 KB |
実行使用メモリ | 61,964 KB |
最終ジャッジ日時 | 2025-05-27 22:03:27 |
合計ジャッジ時間 | 15,528 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 7 RE * 31 |
ソースコード
#include <atcoder/all>
#include <bits/stdc++.h>
#define rep(i, a, b) for (ll i = (ll)(a); i < (ll)(b); i++)
using namespace atcoder;
using namespace std;
typedef long long ll;
using mint = modint998244353;

// Binomial coefficients modulo 998244353 from precomputed factorial tables.
struct Comb {
  vector<mint> fact, ifact;
  int MAX_COM = 0;  // largest n covered by fact / ifact

  Comb() {}
  // Builds fact[0..n] and ifact[0..n]; `mod` is only used to check n < mod.
  Comb(int n, int mod) { init(mod, n); }

  void init(long long MOD, long long max_n) {
    int n = max_n;
    assert(n < MOD);  // every i! with i <= n must be invertible modulo MOD
    MAX_COM = n;      // keep the member in sync (the old parameter shadowed it)
    fact = vector<mint>(n + 1);
    ifact = vector<mint>(n + 1);
    fact[0] = 1;
    for (int i = 1; i <= n; ++i) fact[i] = fact[i - 1] * i;
    ifact[n] = fact[n].inv();
    for (int i = n; i >= 1; --i) ifact[i - 1] = ifact[i] * i;
  }

  // C(n, k); returns 0 when k is outside [0, n]. Asserts that n lies inside
  // the precomputed table instead of silently reading past its end.
  mint operator()(long long n, long long k) {
    if (k < 0 || k > n) return 0;
    assert(n < (long long)fact.size());
    return fact[n] * ifact[k] * ifact[n - k];
  }
};
Comb comb(5000010, 998244353);
vector<mint> p, ps;

// Formal power series over coefficient type T. a[i] is the coefficient of
// x^i; constructors taking a vector strip trailing zeros.
// Conv selects the multiplication backend: `void` uses schoolbook O(n*m)
// multiplication, any other type uses atcoder::convolution (which requires T
// to be an NTT-friendly static modint).
template <typename T, class Conv = void>
struct FPS_Base {
  vector<T> a;

  FPS_Base() {}
  FPS_Base(const vector<T> &a_) : a(a_) { normalize(); }
  // n explicit zero coefficients (not normalized, so size() == n).
  FPS_Base(int n) { a.assign(n, T(0)); }

  // Drops trailing zero coefficients.
  void normalize() {
    while (!a.empty() && a.back() == T(0)) a.pop_back();
  }
  int size() const { return a.size(); }

  // Coefficient of x^i; T(0) for any index outside the stored range.
  T operator[](int i) const {
    return (0 <= i && i < (int)a.size()) ? a[i] : T(0);
  }

  FPS_Base operator+(const FPS_Base &rhs) const {
    int n = max(a.size(), rhs.a.size());
    vector<T> res(n, T(0));
    rep(i, 0, n) res[i] = (*this)[i] + rhs[i];
    return FPS_Base(res);
  }
  FPS_Base operator-(const FPS_Base &rhs) const {
    int n = max(a.size(), rhs.a.size());
    vector<T> res(n, T(0));
    rep(i, 0, n) res[i] = (*this)[i] - rhs[i];
    return FPS_Base(res);
  }

  // Polynomial product.
  FPS_Base operator*(const FPS_Base &rhs) const {
    vector<T> res;
    if (!a.empty() && !rhs.a.empty()) {
      if constexpr (is_same<Conv, void>::value) {
        // No NTT backend for this coefficient type: schoolbook multiplication
        // (previously this branch left `res` empty and returned 0).
        res.assign(a.size() + rhs.a.size() - 1, T(0));
        for (size_t i = 0; i < a.size(); i++)
          for (size_t j = 0; j < rhs.a.size(); j++)
            res[i + j] = res[i + j] + a[i] * rhs.a[j];
      } else {
        res = convolution(a, rhs.a);
      }
    }
    return FPS_Base(res);
  }
  // Scalar product.
  FPS_Base operator*(const T &c) const {
    vector<T> res = a;
    for (auto &x : res) x = x * c;
    return FPS_Base(res);
  }

  // f mod x^deg (deg <= 0 yields the zero series).
  FPS_Base truncate(int deg) const {
    int len = max(0, min((int)a.size(), deg));
    return FPS_Base(vector<T>(a.begin(), a.begin() + len));
  }

  // Reverses the stored coefficients; with deg != -1 the reversed vector is
  // then padded with zeros or cut to exactly deg entries.
  FPS_Base rev(int deg = -1) const {
    vector<T> res(a.rbegin(), a.rend());
    if (deg != -1) res.resize(deg, T(0));
    return FPS_Base(res);
  }

  // Returns {q, r} with *this = rhs * q + r and deg r < deg rhs.
  pair<FPS_Base, FPS_Base> divmod(const FPS_Base &rhs) const {
    int n = a.size(), m = rhs.a.size();
    assert(m > 0);  // division by the zero polynomial
    if (n < m) return {FPS_Base(), *this};
    const int k = n - m + 1;  // number of quotient coefficients
    // rev(q) = rev(f) * rev(g)^{-1} mod x^k; rev(g)[0] is g's nonzero leading
    // coefficient, so the inverse exists.
    vector<T> qv =
        (this->rev().truncate(k) * rhs.rev().inv(k)).truncate(k).a;
    // Pad back to k entries before reversing: normalize() may have removed
    // high-order zeros of rev(q), which are low-order zeros of q.
    qv.resize(k, T(0));
    reverse(qv.begin(), qv.end());
    FPS_Base q(qv);
    FPS_Base r = (*this) - (rhs * q);
    r.normalize();
    return {q, r};
  }
  FPS_Base operator/(const FPS_Base &rhs) const { return divmod(rhs).first; }
  FPS_Base operator%(const FPS_Base &rhs) const { return divmod(rhs).second; }

  // s >= 0: divide by x^s (dropping low terms); s < 0: multiply by x^{-s}.
  FPS_Base shift(int s) const {
    if (s >= 0) {
      if ((int)a.size() <= s) return FPS_Base();
      return FPS_Base(vector<T>(a.begin() + s, a.end()));
    }
    vector<T> res(-s, T(0));
    res.insert(res.end(), a.begin(), a.end());
    return FPS_Base(res);
  }

  // Sum of a[i] * rhs.a[i] over the common index range.
  T dot(const FPS_Base &rhs) const {
    int n = min(a.size(), rhs.a.size());
    T res(0);
    rep(i, 0, n) res = res + a[i] * rhs.a[i];
    return res;
  }

  FPS_Base derivative() const {
    if (a.empty()) return FPS_Base();
    vector<T> res(a.size() - 1);
    rep(i, 1, a.size()) res[i - 1] = a[i] * T(i);
    return FPS_Base(res);
  }
  // Antiderivative with zero constant term.
  FPS_Base integral() const {
    vector<T> res(a.size() + 1, T(0));
    rep(i, 0, a.size()) res[i + 1] = a[i] / T(i + 1);
    return FPS_Base(res);
  }

  // Horner-free evaluation at x.
  T eval(const T &x) const {
    T res(0), cur(1);
    for (auto coef : a) {
      res = res + coef * cur;
      cur = cur * x;
    }
    return res;
  }

  // 1 / f mod x^deg by Newton iteration g <- 2g - f g^2.
  // Requires a nonzero constant term.
  FPS_Base inv(int deg) const {
    assert(!a.empty() && a[0] != T(0));
    FPS_Base res(vector<T>{a[0].inv()});
    int cur = 1;
    while (cur < deg) {
      cur <<= 1;
      FPS_Base temp = (this->truncate(cur)) * (res * res);
      res = (res * T(2) - temp).truncate(cur);
    }
    return res.truncate(deg);
  }

  // log f mod x^deg = integral(f' / f); meaningful when f[0] == 1.
  FPS_Base log(int deg) const {
    FPS_Base f = this->truncate(deg);
    FPS_Base f_deriv = f.derivative();
    FPS_Base inv_f = f.inv(deg);
    return (f_deriv * inv_f).integral().truncate(deg);
  }

  // exp f mod x^deg by Newton iteration g <- g (1 + f - log g);
  // meaningful when f[0] == 0.
  FPS_Base exp(int deg) const {
    if (this->a.empty()) {
      vector<T> ret(deg, T(0));
      if (deg > 0) ret[0] = T(1);
      return FPS_Base(ret);
    }
    FPS_Base res(vector<T>{T(1)});
    int cur = 1;
    while (cur < deg) {
      cur <<= 1;
      FPS_Base log_res = res.log(cur);
      FPS_Base truncated = this->truncate(cur);
      FPS_Base diff = truncated - log_res;
      if (diff.a.empty()) diff.a.resize(cur, T(0));
      diff.a[0] = diff.a[0] + T(1);
      res = (res * diff).truncate(cur);
    }
    return res.truncate(deg);
  }

  // f^k mod x^deg (deg defaults to size()); the result has exactly deg
  // stored coefficients unless it is identically zero.
  FPS_Base pow(long long k, int deg = -1) const {
    const int n = this->size();
    if (deg == -1) deg = n;
    if (k == 0) {
      FPS_Base ret(deg);
      if (deg > 0) ret.a[0] = T(1);
      return ret;
    }
    for (int i = 0; i < n; i++) {
      if ((*this)[i] != T(0)) {
        // f = c x^i (1 + g)  =>  f^k = c^k x^{ik} exp(k log(1 + g)).
        T inv_lead = T(1) / (*this)[i];
        FPS_Base tmp = ((*this) * inv_lead).shift(i);
        FPS_Base ret = (tmp.log(deg) * T(k)).exp(deg);
        ret = ret * ((*this)[i].pow(k));
        ret = ret.shift(-(i * k));
        ret = ret.truncate(deg);
        if (ret.size() < deg) ret.a.resize(deg, T(0));
        return ret;
      }
      // x^0..x^i all vanish, so f^k is divisible by x^{(i+1)k}.
      if ((__int128)(i + 1) * k >= deg) return FPS_Base(deg);
    }
    return FPS_Base(deg);
  }

  // Mutable coefficient access; grows the storage when i is past the end.
  T &operator[](int i) {
    if (i >= (int)a.size()) a.resize(i + 1, T(0));
    return a[i];
  }
};
フレンドリー版 --- template <typename T> struct FormalPowerSeriesNTTFrendly : public FPS_Base<T, int> { using Base = FPS_Base<T, int>; FormalPowerSeriesNTTFrendly() : Base() {} FormalPowerSeriesNTTFrendly(const vector<T> &a) : Base(a) {} FormalPowerSeriesNTTFrendly(int n) : Base(n) {} }; // --- 自作convolution版 --- template <typename T> struct FormalPowerSeries : public FPS_Base<T, void> { using Base = FPS_Base<T, void>; FormalPowerSeries() : Base() {} FormalPowerSeries(const vector<T> &a) : Base(a) {} FormalPowerSeries(int n) : Base(n) {} }; using fps = FormalPowerSeriesNTTFrendly<modint998244353>; mint f(int n, int m, int tar) { if (tar < 0) return 0; // fps v(tar + 1); // m = min(m, tar + 1); // rep(i, 0, m) { v[i] = 1; } // auto w = v.pow(n, tar + 1); // mint ret = 0; // rep(i, 0, tar + 1) cerr << w[i].val() << ' '; // cerr << endl; // rep(i, 0, tar + 1) cerr << p[i].val() << ' '; // cerr << endl; // rep(i, 0, tar + 1) { ret += w[i]; } mint ret2 = ps[tar + 1]; // rep(i, 0, tar + 1) { ret2 += p[i]; } if (m == tar) ret2 -= n; return ret2; } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); int n, m; cin >> n >> m; int mx = 2e6; p.resize(mx); ps.resize(mx); rep(i, 0, mx) p[i] = comb(i + n - 1, n - 1).val(); rep(i, 1, mx) ps[i] = ps[i - 1] + p[i - 1]; vector<mint> a(m); mint sm = 0; rep(i, 0, m) { a[i] = f(n, m - i, m - i * n); } mint ans = 0; mint dec = 0; for (int i = m - 1; i >= 0; i--) { ans += i * (a[i] - dec); sm += (a[i] - dec); dec = a[i]; } ans /= sm; cout << ans.val() << endl; }