結果

問題 No.3394 Big Binom
コンテスト
ユーザー 回転
提出日時 2026-04-24 10:38:52
言語 C++23
(gcc 15.2.0 + boost 1.89.0)
コンパイル:
g++-15 -O2 -lm -std=c++23 -Wuninitialized -DONLINE_JUDGE -o a.out _filename_
実行:
./a.out
結果
WA  
実行時間 -
コード長 4,874 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 2,670 ms
コンパイル使用メモリ 234,420 KB
実行使用メモリ 6,400 KB
最終ジャッジ日時 2026-04-24 10:38:58
合計ジャッジ時間 5,071 ms
ジャッジサーバーID
(参考情報)
judge1_1 / judge2_1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 9 WA * 13
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

// Generated by Gemini3.1 Pro
#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
#include <atcoder/convolution>
#include <atcoder/modint>

using namespace std;
using mint = atcoder::modint998244353;
const long long MOD = 998244353;

vector<mint> fact, invFact;

// 階乗と逆元の事前計算 (O(N))
void initFact(int N) {
    fact.assign(N + 1, 1);
    invFact.assign(N + 1, 1);
    for (int i = 1; i <= N; i++) fact[i] = fact[i - 1] * i;
    invFact[N] = fact[N].inv();
    for (int i = N - 1; i >= 0; i--) invFact[i] = invFact[i + 1] * (i + 1);
}

// 多項式の評価点を c だけシフトする関数
vector<mint> shift(const vector<mint>& V, mint c, int m) {
    int d = V.size() - 1;
    vector<mint> A(d + 1), B(d + m + 1);
    
    for (int i = 0; i <= d; i++) {
        A[i] = V[i] * invFact[i] * invFact[d - i];
        if ((d - i) % 2 == 1) A[i] = -A[i];
    }
    
    for (int j = 0; j <= d + m; j++) {
        mint val = c + j - d;
        if (val.val() == 0) B[j] = 0;
        else B[j] = val.inv();
    }
    
    // NTTによる畳み込み O(d log d)
    vector<mint> C = atcoder::convolution(A, B);
    vector<mint> res(m + 1);
    
    mint current_mult = 1;
    for (int j = 0; j <= d; j++) current_mult *= (c - j);
    
    for (int k = 0; k <= m; k++) {
        long long eval_pt = (c.val() + k) % MOD;
        // 評価点が元の [0, d] の範囲に含まれる場合は V から直接取得
        if (eval_pt <= d) {
            res[k] = V[eval_pt];
        } else {
            res[k] = C[k + d] * current_mult;
        }
        
        // 次のステップのための乗数更新
        if (k < m) {
            mint denom = c + k - d;
            if (denom.val() == 0) {
                current_mult = 1;
                for (int j = 0; j <= d; j++) current_mult *= (c + k + 1 - j);
            } else {
                current_mult *= (c + k + 1);
                current_mult *= denom.inv();
            }
        }
    }
    return res;
}

// O(√n log n) で n! mod P を計算する関数
mint fast_fact(long long n) {
    if (n >= MOD) return 0;
    
    long long v = sqrt(n);
    int d = 1;
    vector<mint> V = {0, v};
    int msb = 0;
    while ((1LL << (msb + 1)) <= v) msb++;

    // ダブリングにより次数 v の評価値を求める
    for (int step = msb - 1; step >= 0; step--) {
        vector<mint> V2 = shift(V, d + 1, d - 1);
        vector<mint> V_full = V;
        for (auto x : V2) V_full.push_back(x);
        
        vector<mint> H = shift(V, mint(d) * mint(v).inv(), 2 * d);
        
        vector<mint> next_V(2 * d + 1);
        for (int i = 0; i <= 2 * d; i++) {
            next_V[i] = V_full[i] * H[i];
        }
        
        d = 2 * d;
        V = next_V;
        
        // 奇数次数の場合は +1 補正
        if ((v >> step) & 1) {
            vector<mint> V_next(d + 2);
            for (int k = 0; k <= d; k++) {
                V_next[k] = V[k] * (mint(k) * v + d);
            }
            vector<mint> V_d1 = shift(V, d + 1, 0);
            V_next[d + 1] = V_d1[0] * (mint(d + 1) * v + d);
            
            d = d + 1;
            V = V_next;
        }
    }
    
    // まとめて階乗を計算
    mint fact_n = 1;
    for (int k = 1; k <= v; k++) {
        fact_n *= V[k];
    }
    // 端数を掛ける (最大でも √n 回程度)
    for (long long i = (long long)v * v + 1; i <= n; i++) {
        fact_n *= i;
    }
    return fact_n;
}

int main() {
    // 入出力の高速化
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    long long N, K;
    if (!(cin >> N >> K)) return 0;

    if (K < 0 || K > N) {
        cout << 0 << "\n";
        return 0;
    }

    // リュカの定理の準備
    long long n0 = N % MOD, n1 = N / MOD;
    long long k0 = K % MOD, k1 = K / MOD;

    if (k1 > n1) {
        cout << 0 << "\n";
        return 0;
    }
    
    mint ans1 = 1; // 制約上 n1 <= 1 なので、k1 <= n1 ならば組合せは 1 通り
    
    if (k0 > n0) {
        cout << 0 << "\n";
        return 0;
    }

    mint ans0 = 0;
    if (k0 == 0 || k0 == n0) {
        ans0 = 1;
    } else if (min(k0, n0 - k0) <= 2000000) {
        // 必要計算量が小さい場合は軽量なO(K)ループを使用 (計算負荷の削減)
        long long r = min(k0, n0 - k0);
        mint num = 1, den = 1;
        for (long long i = 1; i <= r; i++) {
            num *= (n0 - i + 1);
            den *= i;
        }
        ans0 = num / den;
    } else {
        // N, Kが巨大な場合は NTTベースの高速階乗アルゴリズムを使用
        initFact(100000); 
        mint fn = fast_fact(n0);
        mint fk = fast_fact(k0);
        mint fnk = fast_fact(n0 - k0);
        ans0 = fn / (fk * fnk);
    }

    mint final_ans = ans1 * ans0;
    cout << final_ans.val() << "\n";

    return 0;
}
0