結果

問題 No.3160 Party Game
ユーザー SnowBeenDiding
提出日時 2025-05-23 20:33:38
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
RE  
実行時間 -
コード長 8,604 bytes
コンパイル時間 7,207 ms
コンパイル使用メモリ 333,828 KB
実行使用メモリ 61,964 KB
最終ジャッジ日時 2025-05-27 22:03:27
合計ジャッジ時間 15,528 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 7 RE * 31
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <atcoder/all>
#include <bits/stdc++.h>
#define rep(i, a, b) for (ll i = (ll)(a); i < (ll)(b); i++)
using namespace atcoder;
using namespace std;

typedef long long ll;
using mint = modint998244353;

/// Binomial coefficients C(n, k) as `mint`, from factorial tables built once.
///
/// `mod` is only checked against the table size (every table index must be
/// invertible); the arithmetic itself is done in `mint`, so it should equal
/// mint's modulus.
struct Comb {
    vector<mint> fact, ifact; // fact[i] = i!, ifact[i] = (i!)^{-1}
    int MAX_COM = -1;         // largest n the tables cover; -1 until init()

    Comb() {}
    Comb(int n, int mod) {
        MAX_COM = n;
        init(mod, MAX_COM);
    }

    /// (Re)builds the tables for 0 <= n <= max_n and records max_n in MAX_COM.
    void init(long long MOD, long long max_n) {
        const int n = static_cast<int>(max_n);
        assert(n < MOD);
        MAX_COM = n; // keep the member in sync even when init() is called directly
        fact.assign(n + 1, mint(0));
        ifact.assign(n + 1, mint(0));
        fact[0] = 1;
        for (int i = 1; i <= n; ++i)
            fact[i] = fact[i - 1] * i;
        // One modular inverse, then walk down: 1/(i-1)! = i * (1/i!).
        ifact[n] = fact[n].inv();
        for (int i = n; i >= 1; --i)
            ifact[i - 1] = ifact[i] * i;
    }

    /// C(n, k); returns 0 when k < 0 or k > n. Otherwise requires n <= MAX_COM.
    mint operator()(long long n, long long k) const {
        if (k < 0 || k > n)
            return 0;
        assert(n <= MAX_COM); // beyond the tables would read out of bounds
        return fact[n] * ifact[k] * ifact[n - k];
    }
};
// Factorial tables for binomial coefficients C(n, k) with n <= 5000010
// (mod 998244353); constructed before main() runs.
Comb comb(5000010, 998244353);

// p[i] = C(i + n - 1, n - 1) and its prefix sums ps[i] = p[0] + ... + p[i - 1];
// both are filled in main(), and ps is read by f().
vector<mint> p, ps;

/// Truncated formal power series / polynomial with coefficients of type T.
///
/// Coefficients are stored lowest degree first in `a`. Construction from a
/// vector and every arithmetic result are normalized (trailing zero
/// coefficients removed), so an empty `a` is the zero series; only
/// `FPS_Base(int n)` and the mutable `operator[]` leave explicit zeros.
///
/// `Conv` picks the multiplication algorithm: `void` uses an O(|f|*|g|)
/// schoolbook product; any other type calls `convolution(a, rhs.a)`
/// (presumably AtCoder's NTT convolution, which needs an NTT-friendly modulus).
///
/// Division-based operations (inv, log, divmod, pow) assume T behaves like a
/// field: `T(1) / x` must be the multiplicative inverse of a nonzero x.
template <typename T, class Conv = void> struct FPS_Base {
    std::vector<T> a;

    FPS_Base() {}
    FPS_Base(const std::vector<T> &a_) : a(a_) { normalize(); }
    /// n explicit zero coefficients (not normalized: size() == n).
    FPS_Base(int n) { a.assign(n, T(0)); }

    /// Drops trailing zero coefficients.
    void normalize() {
        while (!a.empty() && a.back() == T(0))
            a.pop_back();
    }

    /// Number of stored coefficients (degree + 1 when normalized; 0 for zero).
    int size() const { return static_cast<int>(a.size()); }

    /// Coefficient of x^i; any index outside [0, size()) reads as zero.
    T operator[](int i) const { return (0 <= i && i < size()) ? a[i] : T(0); }

    FPS_Base operator+(const FPS_Base &rhs) const {
        const int n = std::max(size(), rhs.size());
        std::vector<T> res(n, T(0));
        for (int i = 0; i < n; ++i)
            res[i] = (*this)[i] + rhs[i];
        return FPS_Base(res);
    }

    FPS_Base operator-(const FPS_Base &rhs) const {
        const int n = std::max(size(), rhs.size());
        std::vector<T> res(n, T(0));
        for (int i = 0; i < n; ++i)
            res[i] = (*this)[i] - rhs[i];
        return FPS_Base(res);
    }

    /// Polynomial product.
    FPS_Base operator*(const FPS_Base &rhs) const {
        std::vector<T> res;
        if constexpr (std::is_same<Conv, void>::value) {
            // No fast convolution for this instantiation: use the schoolbook
            // product so the result is still correct.
            if (!a.empty() && !rhs.a.empty()) {
                res.assign(a.size() + rhs.a.size() - 1, T(0));
                for (std::size_t i = 0; i < a.size(); ++i)
                    for (std::size_t j = 0; j < rhs.a.size(); ++j)
                        res[i + j] = res[i + j] + a[i] * rhs.a[j];
            }
        } else {
            res = convolution(a, rhs.a);
        }
        return FPS_Base(res);
    }

    /// Multiplies every coefficient by the scalar c.
    FPS_Base operator*(const T &c) const {
        std::vector<T> res(a);
        for (T &x : res)
            x = x * c;
        return FPS_Base(res);
    }

    /// f mod x^deg (the first deg coefficients; empty for deg <= 0).
    FPS_Base truncate(int deg) const {
        const int len = std::max(0, std::min(size(), deg));
        return FPS_Base(std::vector<T>(a.begin(), a.begin() + len));
    }

    /// Coefficient reversal. With deg == -1: reverses the stored coefficients.
    /// With size() <= deg: returns x^(deg-1) * f(1/x), i.e. coefficient i is
    /// a[deg-1-i]. With size() > deg: keeps the first deg coefficients of the
    /// reversed sequence (the top deg coefficients of f, highest first).
    FPS_Base rev(int deg = -1) const {
        std::vector<T> res = a;
        // Pad BEFORE reversing so missing high-order zeros end up at the low end.
        if (deg != -1 && (int)res.size() < deg)
            res.resize(deg, T(0));
        std::reverse(res.begin(), res.end());
        if (deg != -1 && (int)res.size() > deg)
            res.resize(deg);
        return FPS_Base(res);
    }

    /// Polynomial division: returns {q, r} with *this = rhs * q + r and
    /// deg r < deg rhs. rhs must be nonzero.
    std::pair<FPS_Base, FPS_Base> divmod(const FPS_Base &rhs) const {
        FPS_Base f = *this, g = rhs;
        f.normalize();
        g.normalize();
        const int n = f.size(), m = g.size();
        assert(m > 0); // division by the zero polynomial
        if (n < m)
            return {FPS_Base(), f};
        const int qlen = n - m + 1;
        // rev(q) = rev(f) * rev(g)^{-1}  (mod x^qlen)
        const FPS_Base q =
            (f.rev() * g.rev().inv(qlen)).truncate(qlen).rev(qlen);
        FPS_Base r = f - g * q;
        r.normalize();
        return {q, r};
    }

    FPS_Base operator/(const FPS_Base &rhs) const { return divmod(rhs).first; }

    FPS_Base operator%(const FPS_Base &rhs) const { return divmod(rhs).second; }

    /// s >= 0: drops the lowest s coefficients (divides by x^s).
    /// s < 0: prepends -s zeros (multiplies by x^-s).
    FPS_Base shift(int s) const {
        if (s >= 0) {
            if (size() <= s)
                return FPS_Base();
            return FPS_Base(std::vector<T>(a.begin() + s, a.end()));
        }
        std::vector<T> res(-s, T(0));
        res.insert(res.end(), a.begin(), a.end());
        return FPS_Base(res);
    }

    /// Sum of a[i] * rhs.a[i] over the common index range.
    T dot(const FPS_Base &rhs) const {
        const std::size_t n = std::min(a.size(), rhs.a.size());
        T res(0);
        for (std::size_t i = 0; i < n; ++i)
            res = res + a[i] * rhs.a[i];
        return res;
    }

    /// Formal derivative f'.
    FPS_Base derivative() const {
        if (a.empty())
            return FPS_Base();
        std::vector<T> res(a.size() - 1);
        for (int i = 1; i < size(); ++i)
            res[i - 1] = a[i] * T(i);
        return FPS_Base(res);
    }

    /// Formal integral with zero constant term.
    FPS_Base integral() const {
        std::vector<T> res(a.size() + 1, T(0));
        for (int i = 0; i < size(); ++i)
            res[i + 1] = a[i] / T(i + 1);
        return FPS_Base(res);
    }

    /// Evaluates the polynomial at x (Horner-free running power).
    T eval(const T &x) const {
        T res(0), power(1);
        for (const T &coef : a) {
            res = res + coef * power;
            power = power * x;
        }
        return res;
    }

    /// 1/f mod x^deg by Newton iteration; requires f[0] != 0.
    FPS_Base inv(int deg) const {
        assert(!a.empty() && a[0] != T(0));
        FPS_Base res(std::vector<T>{T(1) / a[0]});
        for (int cur = 1; cur < deg;) {
            cur <<= 1;
            // g <- 2g - f g^2  (mod x^cur)
            FPS_Base temp = truncate(cur) * (res * res);
            res = (res * T(2) - temp).truncate(cur);
        }
        return res.truncate(deg);
    }

    /// Integral of f'/f mod x^deg; constant term 0, so this is log f when
    /// f[0] == 1. Requires f[0] != 0.
    FPS_Base log(int deg) const {
        const FPS_Base f = truncate(deg);
        return (f.derivative() * f.inv(deg)).integral().truncate(deg);
    }

    /// exp f mod x^deg via Newton iteration; assumes f[0] == 0 (not checked).
    FPS_Base exp(int deg) const {
        if (a.empty()) {
            std::vector<T> ret(std::max(deg, 0), T(0));
            if (deg > 0)
                ret[0] = T(1);
            return FPS_Base(ret);
        }
        FPS_Base res(std::vector<T>{T(1)});
        for (int cur = 1; cur < deg;) {
            cur <<= 1;
            // g <- g * (1 + f - log g)  (mod x^cur)
            FPS_Base diff = truncate(cur) - res.log(cur);
            if (diff.a.empty())
                diff.a.resize(cur, T(0));
            diff.a[0] = diff.a[0] + T(1);
            res = (res * diff).truncate(cur);
        }
        return res.truncate(deg);
    }

    /// f^k mod x^deg (deg defaults to size()); the result has exactly deg
    /// coefficients stored.
    FPS_Base pow(long long k, int deg = -1) const {
        const int n = size();
        if (deg == -1)
            deg = n;
        if (k == 0) {
            FPS_Base ret(deg);
            if (deg > 0)
                ret.a[0] = T(1);
            return ret;
        }
        for (int i = 0; i < n; i++) {
            if ((*this)[i] != T(0)) {
                // f = c x^i (1 + g)  =>  f^k = c^k x^(ik) exp(k log(1 + g)).
                const T lead_inv = T(1) / (*this)[i];
                FPS_Base tmp = ((*this) * lead_inv).shift(i);
                FPS_Base ret = (tmp.log(deg) * T(k)).exp(deg);
                ret = ret * ((*this)[i].pow(k));
                // i * k < deg here, otherwise the previous iteration returned.
                ret = ret.shift(-static_cast<int>(i * k));
                ret = ret.truncate(deg);
                if (ret.size() < deg)
                    ret.a.resize(deg, T(0));
                return ret;
            }
            if ((__int128)(i + 1) * k >= deg)
                return FPS_Base(deg);
        }
        return FPS_Base(deg);
    }

    /// Mutable access to the coefficient of x^i, growing the storage with
    /// zeros when i >= size().
    T &operator[](int i) {
        assert(i >= 0);
        if (i >= size())
            a.resize(i + 1, T(0));
        return a[i];
    }
};

// --- NTT-friendly variant: FPS_Base with Conv = int, so multiplication goes
// through convolution() ---
template <typename T>
struct FormalPowerSeriesNTTFrendly : public FPS_Base<T, int> {
    using Base = FPS_Base<T, int>;
    using Base::Base; // from-vector and zero-filled(int n) constructors
    FormalPowerSeriesNTTFrendly() = default;
};

// --- Custom-convolution variant: FPS_Base with Conv = void (the non-NTT
// multiplication branch) ---
template <typename T> struct FormalPowerSeries : public FPS_Base<T, void> {
    using Base = FPS_Base<T, void>;
    using Base::Base; // from-vector and zero-filled(int n) constructors
    FormalPowerSeries() = default;
};

using fps = FormalPowerSeriesNTTFrendly<modint998244353>;

/// For n >= 1, counts (mod 998244353) the n-tuples of nonnegative integers
/// whose sum is at most `tar` and whose entries are all at most m - 1.
///
/// Uses the global prefix sums ps[tar + 1] = sum_{s <= tar} C(s + n - 1, n - 1)
/// (stars and bars, no upper bound on entries). That closed form is exact when
/// m > tar (the bound never binds), and after subtracting n when m == tar >= 1
/// (it drops the n tuples with a single entry equal to tar). It does NOT handle
/// m < tar; callers must not rely on it there.
///
/// Returns 0 for tar < 0. Requires ps to hold at least tar + 2 entries.
mint f(int n, int m, int tar) {
    if (tar < 0)
        return 0;
    const size_t idx = (size_t)tar + 1;
    assert(idx < ps.size());
    mint ret = ps[idx];
    if (m == tar)
        ret -= n;
    return ret;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> n >> m;
    int mx = 2e6;
    p.resize(mx);
    ps.resize(mx);
    rep(i, 0, mx) p[i] = comb(i + n - 1, n - 1).val();
    rep(i, 1, mx) ps[i] = ps[i - 1] + p[i - 1];
    vector<mint> a(m);
    mint sm = 0;
    rep(i, 0, m) { a[i] = f(n, m - i, m - i * n); }
    mint ans = 0;
    mint dec = 0;
    for (int i = m - 1; i >= 0; i--) {
        ans += i * (a[i] - dec);
        sm += (a[i] - dec);
        dec = a[i];
    }
    ans /= sm;
    cout << ans.val() << endl;
}
0