結果

問題 No.3394 Big Binom
コンテスト
ユーザー umimel
提出日時 2025-12-01 00:47:24
言語 C++14
(gcc 13.3.0 + boost 1.89.0)
結果
AC  
実行時間 47 ms / 2,000 ms
コード長 6,105 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 3,046 ms
コンパイル使用メモリ 207,952 KB
実行使用メモリ 9,088 KB
最終ジャッジ日時 2025-12-14 19:58:39
合計ジャッジ時間 4,682 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 22
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

// https://judge.yosupo.jp/submission/171341
#define PROBLEM "https://judge.yosupo.jp/problem/factorial"

#include <bits/stdc++.h>
using namespace std;
#include <iostream>
#include <atcoder/modint>

using mint = atcoder::modint998244353;

#include <atcoder/convolution>
#include <cmath>

#include <cassert>
#include <vector>

namespace suisen {
    template <typename T, typename U = T>
    struct factorial {
        factorial() = default;
        factorial(int n) { ensure(n); }

        static void ensure(const int n) {
            int sz = _fac.size();
            if (n + 1 <= sz) return;
            int new_size = std::max(n + 1, sz * 2);
            _fac.resize(new_size), _fac_inv.resize(new_size);
            for (int i = sz; i < new_size; ++i) _fac[i] = _fac[i - 1] * i;
            _fac_inv[new_size - 1] = U(1) / _fac[new_size - 1];
            for (int i = new_size - 1; i > sz; --i) _fac_inv[i - 1] = _fac_inv[i] * i;
        }

        T fac(const int i) {
            ensure(i);
            return _fac[i];
        }
        T operator()(int i) {
            return fac(i);
        }
        U fac_inv(const int i) {
            ensure(i);
            return _fac_inv[i];
        }
        U binom(const int n, const int r) {
            if (n < 0 or r < 0 or n < r) return 0;
            ensure(n);
            return _fac[n] * _fac_inv[r] * _fac_inv[n - r];
        }
        U perm(const int n, const int r) {
            if (n < 0 or r < 0 or n < r) return 0;
            ensure(n);
            return _fac[n] * _fac_inv[n - r];
        }
    private:
        static std::vector<T> _fac;
        static std::vector<U> _fac_inv;
    };
    template <typename T, typename U>
    std::vector<T> factorial<T, U>::_fac{ 1 };
    template <typename T, typename U>
    std::vector<U> factorial<T, U>::_fac_inv{ 1 };
} // namespace suisen

#include <atcoder/convolution>

namespace suisen {
    template <typename mint>
    std::vector<mint> shift_of_sampling_points(const std::vector<mint>& ys, mint t, int m) {
        const int n = ys.size();
        factorial<mint> fac(std::max(n, m));

        std::vector<mint> b = [&] {
            std::vector<mint> f(n), g(n);
            for (int i = 0; i < n; ++i) {
                f[i] = ys[i] * fac.fac_inv(i);
                g[i] = (i & 1 ? -1 : 1) * fac.fac_inv(i);
            }
            std::vector<mint> b = atcoder::convolution(f, g);
            b.resize(n);
            return b;
        }();
        std::vector<mint> e = [&] {
            std::vector<mint> c(n);
            mint prd = 1;
            std::reverse(b.begin(), b.end());
            for (int i = 0; i < n; ++i) {
                b[i] *= fac.fac(n - i - 1);
                c[i] = prd * fac.fac_inv(i);
                prd *= t - i;
            }
            std::vector<mint> e = atcoder::convolution(b, c);
            e.resize(n);
            return e;
        }();
        std::reverse(e.begin(), e.end());
        for (int i = 0; i < n; ++i) {
            e[i] *= fac.fac_inv(i);
        }

        std::vector<mint> f(m);
        for (int i = 0; i < m; ++i) f[i] = fac.fac_inv(i);
        std::vector<mint> res = atcoder::convolution(e, f);
        res.resize(m);
        for (int i = 0; i < m; ++i) res[i] *= fac.fac(i);
        return res;
    }
} // namespace suisen

namespace suisen {
    template <typename mint>
    struct FactorialLarge {
        static constexpr int _p = mint::mod();
        static constexpr int _log_b = 12;
        static constexpr int _b = 1 << _log_b;
        static constexpr int _q = _p >> _log_b;

        FactorialLarge() {
            // f_d(x) := (dx+1) * ... * (dx+d-1)

            // Suppose that we have f_d(0),...,f_d(d-1).
            // f_{2d}(x) = ((2dx+1) * ... * (2dx+d-1)) * (2dx+d) * (((2dx+d)+1) * ... * ((2dx+d)+d-1))
            //           = f_d(2x) * f_d(2x+1) * (2dx+d)

            // We can calculate f_{2d}(0), ..., f_{2d}(2d-1) from f_d(0), f_d(1), ..., f_d(4d-2), f_d(4d-1)
            
            // f_1
            std::vector<mint> f{ 1 };
            for (int i = 0; i < _log_b; ++i) {
                std::vector<mint> g = shift_of_sampling_points<mint>(f, 1 << i, 3 << i);
                const auto get = [&](int j) { return j < (1 << i) ? f[j] : g[j - (1 << i)]; };
                f.resize(2 << i);
                for (int j = 0; j < 2 << i; ++j) {
                    f[j] = get(2 * j) * get(2 * j + 1) * ((2 * j + 1) << i);
                }
            }
            // f_B(x) = (x+1) * ... * (x+B-1)
            if (_q > _b) {
                std::vector<mint> g = shift_of_sampling_points<mint>(f, _b, _q - _b);
                std::move(g.begin(), g.end(), std::back_inserter(f));
            } else {
                f.resize(_q);
            }
            for (int i = 0; i < _q; ++i) {
                f[i] *= mint(i + 1) * _b;
            }
            // f[i] = (i*B + 1) * ... * (i*B + B)
            _acc = std::move(f);

            _acc.insert(_acc.begin(), 1);
            for (int i = 1; i <= _q; ++i) {
                _acc[i] *= _acc[i - 1];
            }
        }

        mint operator()(long long n) {
            if (_p <= n) return 0;
            const int q = n >> _log_b, r = n & (_b - 1);
            // n! = (qb)! * (n-r+1)(n-r+2)...(n)
            mint ans = _acc[q];
            for (int j = 0; j < r; ++j) {
                ans *= mint::raw(n - j);
            }
            return ans;
        }
    private:
        std::vector<mint> _acc;
    };
} // namespace suisen

int main() {
    suisen::FactorialLarge<mint> fact;

    int n, k;
    std::cin >> n >> k;
    int m = 998244353;
    if(n < m){
        mint x = fact(n), y = fact(n-k), z = fact(k);
        mint ans = x*y.inv()*z.inv();

        cout << ans.val() << "\n";
    }else{
        if(m <= k || m <= n-k){
            k = min(k, n-k);
            mint ans = 1;
            for(int i=0; i<k; i++) if(n-i != m) ans *= mint(n-i);
            for(int i=1; i<=k; i++) ans *= mint(i).inv();

            cout << ans.val() << "\n";
        }else{
            cout << 0 << "\n";
        }
    }
}
0