結果

問題 No.2141 Enumeratest
ユーザー kanra824kanra824
提出日時 2022-12-02 22:30:42
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
WA  
実行時間 -
コード長 5,778 bytes
コンパイル時間 1,686 ms
コンパイル使用メモリ 172,020 KB
実行使用メモリ 14,848 KB
最終ジャッジ日時 2024-10-10 00:02:27
合計ジャッジ時間 3,315 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
5,248 KB
testcase_01 AC 2 ms
5,248 KB
testcase_02 AC 8 ms
6,020 KB
testcase_03 AC 2 ms
5,248 KB
testcase_04 AC 36 ms
14,848 KB
testcase_05 WA -
testcase_06 AC 2 ms
5,248 KB
testcase_07 AC 4 ms
5,248 KB
testcase_08 AC 10 ms
7,296 KB
testcase_09 AC 8 ms
5,892 KB
testcase_10 AC 15 ms
9,856 KB
testcase_11 AC 20 ms
13,708 KB
testcase_12 AC 3 ms
5,248 KB
testcase_13 AC 2 ms
5,248 KB
testcase_14 AC 20 ms
13,588 KB
testcase_15 AC 14 ms
9,216 KB
testcase_16 AC 16 ms
10,880 KB
testcase_17 WA -
testcase_18 WA -
testcase_19 WA -
testcase_20 WA -
testcase_21 WA -
testcase_22 WA -
testcase_23 WA -
testcase_24 AC 23 ms
14,336 KB
testcase_25 AC 10 ms
7,040 KB
testcase_26 WA -
testcase_27 WA -
testcase_28 WA -
testcase_29 WA -
testcase_30 WA -
testcase_31 WA -
testcase_32 WA -
testcase_33 WA -
testcase_34 AC 26 ms
12,176 KB
testcase_35 WA -
testcase_36 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>

using namespace std;

#define REP(i, n) for (int i = 0; i < (n); ++i)
#define RREP(i, n) for (int i = (n); i >= 0; --i)
#define FOR(i, a, n) for (int i = (a); i < (n); ++i)
#define RFOR(i, a, b) for (int i = (a); i >= (b); --i)

#define SZ(x) ((int)(x).size())
#define ALL(x) (x).begin(), (x).end()

template <class T>
ostream& operator<<(ostream& os, const vector<T>& v) {
    REP(i, SZ(v)) {
        if (i) os << " ";
        os << v[i];
    }
    return os;
}

template <class T, class U>
ostream& operator<<(ostream& os, const pair<T, U>& p) {
    os << p.first << " " << p.second;
    return os;
}

template <class T>
bool chmax(T& a, const T& b) {
    bool res = a < b;
    if (a < b) a = b;
    return res;
}

template <class T>
bool chmin(T& a, const T& b) {
    bool res = a > b;
    if (a > b) a = b;
    return res;
}

using ll = long long;
using ull = unsigned long long;
using ld = long double;
using P = pair<int, int>;
using PLL = pair<ll, ll>;
using vi = vector<int>;
using vll = vector<ll>;
using vvi = vector<vi>;
using vvll = vector<vll>;

const ll MOD = 1e9 + 7;
const ll MOD998 = 998244353;
const int INF = INT_MAX;
const ll LINF = LLONG_MAX;
const int inf = INT_MIN;
const ll linf = LLONG_MIN;
const ld eps = 1e-9;

template<int m>
struct mint {
    int x;
    mint(ll x = 0) : x(((x % m) + m) % m) {}
    mint operator-() const { return x ? m-x : 0; }
    mint &operator+=(mint r) {
        if ((x += r.x) >= m) x -= m;
        return *this;
    }
    mint &operator-=(mint r) {
        if ((x -= r.x) < 0) x += m;
        return *this;
    }
    mint &operator*=(mint r) {
        x = ((ll)x * r.x) % m;
        return *this;
    }
    mint inv() const { return pow(m-2); }
    mint &operator/=(mint r) { return *this *= r.inv(); }

    friend mint operator+(mint l, mint r) { return l += r; }
    friend mint operator-(mint l, mint r) { return l -= r; }
    friend mint operator*(mint l, mint r) { return l *= r; }
    friend mint operator/(mint l, mint r) { return l /= r; }
    mint pow(ll n) const {
        mint ret = 1, tmp = *this;
        while (n) {
            if (n & 1) ret *= tmp;
            tmp *= tmp, n >>= 1;
        }
        return ret;
    }
    friend bool operator==(mint l, mint r) { return l.x == r.x; }
    friend bool operator!=(mint l, mint r) { return l.x != r.x; }
    friend ostream &operator<<(ostream &os, mint a) {
        return os << a.x;
    }
    friend istream &operator>>(istream &is, mint& a) {
        ll x; is >> x; a = x; return is;
    }
};

using Int = mint<MOD998>;

template<typename T>
struct Combination {
    int _n = 1;
    vector<T> _fact{1}, _rfact{1};

    void extend(int n) {
        if (n <= _n) return;
        _fact.resize(n);
        _rfact.resize(n);
        for (int i = _n; i < n; ++i) _fact[i] = _fact[i - 1] * i;
        _rfact[n - 1] = 1 / _fact[n - 1];
        for (int i = n - 1; i > _n; --i) _rfact[i - 1] = _rfact[i] * i;
        _n = n;
    }

    T fact(int k) {
        extend(k + 1);
        return _fact.at(k);
    }

    T rfact(int k) {
        extend(k + 1);
        return _rfact.at(k);
    }

    T P(int n, int r) {
        if (r < 0 or n < r) return 0;
        return fact(n) * rfact(n - r);
    }

    T C(int n, int r) {
        if (r < 0 or n < r) return 0;
        return fact(n) * rfact(r) * rfact(n - r);
    }

    T H(int n, int r) {
        return (n == 0 and r == 0) ? 1 : C(n + r - 1, r);
    }

    // O(k logn)
    // スターリング数
    // Stirling(n, k) := n 個の区別できるボールを k 個の区別できない箱にいれる場合の数
    //                   それぞれの箱には1個以上ボールをいれる
    // ---
    // S(n, k) = S(n-1, k-1) + k * S(n-1, k)
    // * 特定の1個だけで1個の箱にいれる場合は S(n-1, k-1)
    // * そうでない場合は S(n-1, k) 通りに対して特定の1個をいれるのが k 通り
    // ---
    // 各グループにつきr個以上, の制限がある場合
    // S(n, k) = C(n-1, r-1) * S(n-r, k-1) + k * S(n-1, k)
    // ---
    // 玉がn個あるうちのいくつかを選んでkグループに分ける場合
    // S(n, k) = S(n-1, k-1) + (k+1) * S(n-1, k)
    T Stirling(ll n, int k) {
        T ret = 0;
        for (int l = 0; l <= k; ++l) {
            ret += C(k, l) * T{k-l}.pow(n) * (l & 1 ? -1 : 1);
        }
        return ret / fact(k);
    }

    // O(k^2 logn)
    // ベル数
    // Bell(n, k) := n 個の区別できるボールを k 個の区別できない箱にいれる場合の数
    // ---
    // B(n+1) := Bell(n+1, n+1) = n 個の区別できるボールの分割の総数
    // B(n+1) = \sum_{i=0}^n C(n,i) * B(i)
    // * 特定の1個が属するグループに, 他のボールがn-i 個入っているとき,
    // * 残りi 個の並べ方はB(i)
    T Bell(ll n, int k) {
        T ret = 0;
        for (int l = 0; l <= k; ++l) {
            ret += Stirling(n, l);
        }
        return ret;
    }

};

int main() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    cout << fixed << setprecision(10);

    int n, m; cin >> n >> m;
    Combination<Int> comb;

    Int ans = 1;
    if(n > m) {
        ans = comb.P(n, m);
    } else {
        int val = m / n;
        int cnt = m % n;
        int now = m;
        REP(i, n-1) {
            if(i < cnt) {
                ans *= comb.C(now, m / n + 1);
                now -= m / n + 1;
            } else {
                ans *= comb.C(now, m / n);
                now -= m / n;
            }
        }
    }
    cout << ans << endl;
}
0