結果

問題 No.3394 Big Binom
コンテスト
ユーザー glass_256
提出日時 2025-12-01 22:52:34
言語 C++23
(gcc 13.3.0 + boost 1.89.0)
結果
AC  
実行時間 227 ms / 2,000 ms
コード長 9,872 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 5,134 ms
コンパイル使用メモリ 326,252 KB
実行使用メモリ 51,268 KB
最終ジャッジ日時 2025-12-14 20:01:43
合計ジャッジ時間 8,397 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 22
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

//https://suisen-kyopro.hatenablog.com/entry/2023/11/22/201600 よりお借りした


#include <atcoder/modint>
#include <atcoder/convolution>

template <typename mint,
    std::enable_if_t<atcoder::internal::is_static_modint<mint>::value, std::nullptr_t> = nullptr>
std::vector<mint> arbitrary_mod_convolution(const std::vector<mint>& a, const std::vector<mint>& b) {
    int n = int(a.size()), m = int(b.size());

    {
        // check if the mod is ntt-friendly
        int maxz = 1;
        while (not ((mint::mod() - 1) & maxz)) {
            maxz <<= 1;
        }
        int z = 1;
        while (z < n + m - 1) {
            z <<= 1;
        }
        if (z <= maxz) {
            return atcoder::convolution<mint>(a, b);
        }
    }

    if (n == 0 or m == 0) return {};
    //if (std::min(n, m) <= 120) return atcoder::internal::convolution_naive(a, b);

    static constexpr long long MOD1 = 754974721;  // 2^24
    static constexpr long long MOD2 = 167772161;  // 2^25
    static constexpr long long MOD3 = 469762049;  // 2^26
    static constexpr long long M1M2 = MOD1 * MOD2;
    static constexpr long long INV_M1_MOD2 = atcoder::internal::inv_gcd(MOD1, MOD2).second;
    static constexpr long long INV_M1M2_MOD3 = atcoder::internal::inv_gcd(M1M2, MOD3).second;

    std::vector<int> a2(n), b2(m);
    for (int i = 0; i < n; ++i) a2[i] = a[i].val();
    for (int i = 0; i < m; ++i) b2[i] = b[i].val();

    auto c1 = atcoder::convolution<MOD1>(a2, b2);
    auto c2 = atcoder::convolution<MOD2>(a2, b2);
    auto c3 = atcoder::convolution<MOD3>(a2, b2);

    const long long m1m2 = mint(M1M2).val();
    std::vector<mint> c(n + m - 1);
    for (int i = 0; i < n + m - 1; ++i) {
        // Garner's Algorithm
        // X = x1 + x2 * m1 + x3 * m1 * m2
        // x1 = c1[i], x2 = (c2[i] - x1) / m1 (mod m2), x3 = (c3[i] - x1 - x2 * m1) / m2 (mod m3)
        long long x1 = c1[i];
        long long x2 = (atcoder::static_modint<MOD2>(c2[i] - x1) * INV_M1_MOD2).val();
        long long x3 = (atcoder::static_modint<MOD3>(c3[i] - x1 - x2 * MOD1) * INV_M1M2_MOD3).val();
        c[i] = x1 + x2 * MOD1 + x3 * m1m2;
    }
    return c;
}

/**
 * Lazily grown, shared tables of factorials and inverse factorials in `mint`.
 *
 * All members are static; the tables are extended on demand by ensure().
 * Arguments are expected to be non-negative (no range check is performed).
 */
template <typename mint,
    std::enable_if_t<atcoder::internal::is_static_modint<mint>::value, std::nullptr_t> = nullptr>
struct factorial {
    using value_type = mint;

    factorial() = delete;

    /** Grows both tables so that indices 0 .. size-1 are available. */
    static void ensure(int size) {
        const int old_size = _fact.size();
        if (old_size >= size) return;

        // At least double the tables so repeated small requests stay amortised.
        const int new_size = std::max(2 * old_size, size);
        _fact.resize(new_size);
        _inv_fact.resize(new_size);

        for (int i = old_size; i < new_size; ++i) {
            _fact[i] = _fact[i - 1] * i;
        }
        // One modular inverse at the top, then walk down: 1/i! = (i+1) / (i+1)!.
        _inv_fact[new_size - 1] = _fact[new_size - 1].inv();
        for (int i = new_size - 2; i >= old_size; --i) {
            _inv_fact[i] = _inv_fact[i + 1] * (i + 1);
        }
    }

    /** Returns n!. */
    static value_type fact(int n) {
        ensure(n + 1);
        return _fact[n];
    }

    /** Returns 1 / n!. */
    static value_type inv_fact(int n) {
        ensure(n + 1);
        return _inv_fact[n];
    }

    /** Returns n! / (r! (n-r)!), or 0 when r is outside [0, n]. */
    static value_type binom(int n, int r) {
        if (0 <= r && r <= n) {
            return fact(n) * inv_fact(r) * inv_fact(n - r);
        }
        return 0;
    }

    /** Returns n! / (n-r)!, or 0 when r is outside [0, n]. */
    static value_type perm(int n, int r) {
        if (0 <= r && r <= n) {
            return fact(n) * inv_fact(n - r);
        }
        return 0;
    }

private:
    // Both tables start with the entry for 0! = 1.
    static inline std::vector<value_type> _fact{ 1 }, _inv_fact{ 1 };
};

/**
 * Computes f(t),f(t+1),...,f(t+m-1) from f(0),f(1),...,f(n-1)
 */
template <typename mint, typename Convolve,
    std::enable_if_t<std::conjunction_v<
    atcoder::internal::is_static_modint<mint>,
    std::is_invocable_r<std::vector<mint>, Convolve, std::vector<mint>, std::vector<mint>>
>, std::nullptr_t> = nullptr>
std::vector<mint> shift_of_sampling_points(const std::vector<mint>& ys, mint t, int m, const Convolve& convolve) {
    const int n = ys.size();
    factorial<mint>::ensure(std::max(n, m));
    std::vector<mint> b = [&] {
        std::vector<mint> f(n), g(n);
        for (int i = 0; i < n; ++i) {
            f[i] = ys[i] * factorial<mint>::inv_fact(i);
            g[i] = (i & 1 ? -1 : 1) * factorial<mint>::inv_fact(i);
        }
        std::vector<mint> b = convolve(f, g);
        b.resize(n);
        return b;
        }();
    std::vector<mint> e = [&] {
        std::vector<mint> c(n);
        mint prd = 1;
        std::reverse(b.begin(), b.end());
        for (int i = 0; i < n; ++i) {
            b[i] *= factorial<mint>::fact(n - i - 1);
            c[i] = prd * factorial<mint>::inv_fact(i);
            prd *= t - i;
        }
        std::vector<mint> e = convolve(b, c);
        e.resize(n);
        return e;
        }();
    std::reverse(e.begin(), e.end());
    for (int i = 0; i < n; ++i) {
        e[i] *= factorial<mint>::inv_fact(i);
    }
    std::vector<mint> f(m);
    for (int i = 0; i < m; ++i) {
        f[i] = factorial<mint>::inv_fact(i);
    }
    std::vector<mint> res = convolve(e, f);
    res.resize(m);
    for (int i = 0; i < m; ++i) {
        res[i] *= factorial<mint>::fact(i);
    }
    return res;
}

/**
 * Shift of sampling points using atcoder::convolution as the multiplier.
 *
 * Given ys[i] = f(i) for i = 0 .. ys.size()-1, returns
 * f(t), f(t+1), ..., f(t+m-1).
 */
template <typename mint,
    std::enable_if_t<atcoder::internal::is_static_modint<mint>::value, std::nullptr_t> = nullptr>
std::vector<mint> shift_of_sampling_points(const std::vector<mint>& ys, mint t, int m) {
    using poly = std::vector<mint>;
    return shift_of_sampling_points(ys, t, m, [](const poly& lhs, const poly& rhs) -> poly {
        return atcoder::convolution(lhs, rhs);
    });
}

/**
 * Factorials modulo mint::mod() for arguments far beyond what a plain table
 * can hold.
 *
 * Arguments up to `threshold` are answered from the ordinary factorial<mint>
 * table.  Larger arguments use a table of (q * BLOCK_SIZE)! for every block
 * index q, built once on first use, and then finish with at most
 * BLOCK_SIZE / 2 multiplications (or divisions).
 *
 * Preconditions: n >= 0.  inv_fact(), binom() and perm() additionally need
 * every factorial they invert to be non-zero, i.e. their arguments must be
 * smaller than mint::mod().  The Wilson's-theorem branch of the
 * implementation assumes a prime modulus.
 */
template <typename mint,
    std::enable_if_t<atcoder::internal::is_static_modint<mint>::value, std::nullptr_t> = nullptr>
struct factorial_large {
    using value_type = mint;

    static constexpr int LOG_BLOCK_SIZE = 9;
    static constexpr int BLOCK_SIZE = 1 << LOG_BLOCK_SIZE;
    // Number of complete blocks of BLOCK_SIZE consecutive integers below the modulus.
    static constexpr int BLOCK_NUM = value_type::mod() >> LOG_BLOCK_SIZE;

    // Arguments up to this value are served by the plain factorial<mint> table.
    static inline int threshold = 2000000;

    factorial_large() = delete;

    /** Returns n! (which is 0 whenever n >= mint::mod()). */
    static value_type fact(int n) {
        return n <= threshold ? factorial<mint>::fact(n) : _large_fact(n);
    }
    /** Returns 1 / n!; requires n < mint::mod(). */
    static value_type inv_fact(int n) {
        return n <= threshold ? factorial<mint>::inv_fact(n) : _large_fact(n).inv();
    }
    /** Returns n! / (r! (n-r)!), or 0 when r is outside [0, n]; requires n < mint::mod(). */
    static value_type binom(int n, int r) {
        if (r < 0 or r > n) return 0;
        return fact(n) * inv_fact(r) * inv_fact(n - r);
    }
    /** Returns n! / (n-r)!, or 0 when r is outside [0, n]; requires n < mint::mod(). */
    static value_type perm(int n, int r) {
        if (r < 0 or r > n) return 0;
        return fact(n) * inv_fact(n - r);
    }
private:
    // _block_fact[q] = (q * BLOCK_SIZE)! for q = 0 .. BLOCK_NUM; empty until _build() runs.
    static inline std::vector<value_type> _block_fact{};

    /** Fills _block_fact on the first call; later calls return immediately. */
    static void _build() {
        if (_block_fact.size()) {
            return;
        }
        // Invariant with d = 2^i: f[j] = (j*d + 1) * (j*d + 2) * ... * (j*d + d - 1)
        // for j = 0 .. d-1.  Each round extends the samples to j = 0 .. 4d-1 and
        // merges neighbouring pairs to obtain the same table for 2d.
        std::vector<value_type> f{ 1 };
        f.reserve(BLOCK_SIZE);
        for (int i = 0; i < LOG_BLOCK_SIZE; ++i) {
            std::vector<value_type> g = shift_of_sampling_points<value_type>(f, 1 << i, 3 << i, arbitrary_mod_convolution<value_type>);
            const auto get = [&](int j) { return j < (1 << i) ? f[j] : g[j - (1 << i)]; };
            f.resize(2 << i);
            for (int j = 0; j < 2 << i; ++j) {
                // The extra factor (2j+1)*d is the term sitting between the two halves.
                f[j] = get(2 * j) * get(2 * j + 1) * ((2 * j + 1) << i);
            }
        }
        // Now, with B = BLOCK_SIZE: f[j] = (j*B + 1) * ... * (j*B + B - 1) for j = 0 .. B-1.
        // Extend (or truncate) the samples so that they cover j = 0 .. BLOCK_NUM-1.
        if (BLOCK_NUM > BLOCK_SIZE) {
            std::vector<value_type> g = shift_of_sampling_points<value_type>(f, BLOCK_SIZE, BLOCK_NUM - BLOCK_SIZE, arbitrary_mod_convolution<value_type>);
            std::move(g.begin(), g.end(), std::back_inserter(f));
        }
        else {
            f.resize(BLOCK_NUM);
        }
        for (int i = 0; i < BLOCK_NUM; ++i) {
            f[i] *= value_type(i + 1) * BLOCK_SIZE;
        }
        // f[i] = (i*B + 1) * ... * (i*B + B)

        // Prefix products turn the per-block products into (q*B)!.
        f.insert(f.begin(), 1);
        for (int i = 1; i <= BLOCK_NUM; ++i) {
            f[i] *= f[i - 1];
        }
        _block_fact = std::move(f);
    }

    /** Returns n! using the block table; n is expected to be non-negative. */
    static value_type _large_fact(int n) {
        // n! contains the modulus as a factor once n reaches it.  Returning here
        // also keeps the block index below within _block_fact and the arguments
        // of raw() below the modulus.
        if (n >= value_type::mod()) {
            return 0;
        }
        _build();
        value_type res;
        int q = n / BLOCK_SIZE, r = n % BLOCK_SIZE;
        if (2 * r <= BLOCK_SIZE) {
            // Closer to the block start: multiply up from (q*B)!.
            res = _block_fact[q];
            for (int i = 0; i < r; ++i) {
                res *= value_type::raw(n - i);
            }
        }
        else if (q != factorial_large::BLOCK_NUM) {
            // Closer to the next block start: divide down from ((q+1)*B)!.
            res = _block_fact[q + 1];
            value_type den = 1;
            for (int i = 1; i <= BLOCK_SIZE - r; ++i) {
                den *= value_type::raw(n + i);
            }
            res /= den;
        }
        else {
            // Wilson's theorem: (mod - 1)! = -1, so divide down from there.
            res = value_type::mod() - 1;
            value_type den = 1;
            for (int i = value_type::mod() - 1; i > n; --i) {
                den *= value_type::raw(i);
            }
            res /= den;
        }
        return res;
    }
};

#include <bits/stdc++.h>
using namespace std;
using namespace atcoder;

using mint = atcoder::modint998244353;

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    //int t;
    //std::cin >> t;

    //while (t--) {
    //    int n;
    //    std::cin >> n;

    //    std::cout << factorial_large<mint>::fact(n).val() << '\n';
    //}

    int N, K;
    cin >> N >> K;

    K = min(K, N - K);

    if (N - K >= 998244353) {
        modint ans = 1;
        for (int i = N - K + 1;i <= N; i++) {
            ans *= i;
        }
        modint div = 1;
        for (int i = 1;i <= K;i++) {
            div *= i;
        }
        ans *= div.inv();

        cout << ans.val() << endl;
    }
    else if (N >= 998244353 && N - K <= 998244353) {
        cout << 0 << endl;
    }
    else {
        modint a = factorial_large<mint>::fact(int(N)).val();
        modint b = factorial_large<mint>::fact(int(K)).val();
        modint c = factorial_large<mint>::fact(int(N - K)).val();

        b *= c;
        a *= b.inv();

        cout << a.val() << endl;
    }

    return 0;
}
0