結果
| 問題 | No.3394 Big Binom |
| コンテスト | |
| ユーザー |
回転
|
| 提出日時 | 2026-04-24 10:38:52 |
| 言語 | C++23 (gcc 15.2.0 + boost 1.89.0) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 4,874 bytes |
| 記録 | |
| コンパイル時間 | 2,670 ms |
| コンパイル使用メモリ | 234,420 KB |
| 実行使用メモリ | 6,400 KB |
| 最終ジャッジ日時 | 2026-04-24 10:38:58 |
| 合計ジャッジ時間 | 5,071 ms |
|
ジャッジサーバーID (参考情報) |
judge1_1 / judge2_1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 4 |
| other | AC * 9 WA * 13 |
ソースコード
// Generated by Gemini3.1 Pro
#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
#include <atcoder/convolution>
#include <atcoder/modint>
using namespace std;
using mint = atcoder::modint998244353;
const long long MOD = 998244353;
vector<mint> fact, invFact;
// 階乗と逆元の事前計算 (O(N))
void initFact(int N) {
fact.assign(N + 1, 1);
invFact.assign(N + 1, 1);
for (int i = 1; i <= N; i++) fact[i] = fact[i - 1] * i;
invFact[N] = fact[N].inv();
for (int i = N - 1; i >= 0; i--) invFact[i] = invFact[i + 1] * (i + 1);
}
// 多項式の評価点を c だけシフトする関数
vector<mint> shift(const vector<mint>& V, mint c, int m) {
int d = V.size() - 1;
vector<mint> A(d + 1), B(d + m + 1);
for (int i = 0; i <= d; i++) {
A[i] = V[i] * invFact[i] * invFact[d - i];
if ((d - i) % 2 == 1) A[i] = -A[i];
}
for (int j = 0; j <= d + m; j++) {
mint val = c + j - d;
if (val.val() == 0) B[j] = 0;
else B[j] = val.inv();
}
// NTTによる畳み込み O(d log d)
vector<mint> C = atcoder::convolution(A, B);
vector<mint> res(m + 1);
mint current_mult = 1;
for (int j = 0; j <= d; j++) current_mult *= (c - j);
for (int k = 0; k <= m; k++) {
long long eval_pt = (c.val() + k) % MOD;
// 評価点が元の [0, d] の範囲に含まれる場合は V から直接取得
if (eval_pt <= d) {
res[k] = V[eval_pt];
} else {
res[k] = C[k + d] * current_mult;
}
// 次のステップのための乗数更新
if (k < m) {
mint denom = c + k - d;
if (denom.val() == 0) {
current_mult = 1;
for (int j = 0; j <= d; j++) current_mult *= (c + k + 1 - j);
} else {
current_mult *= (c + k + 1);
current_mult *= denom.inv();
}
}
}
return res;
}
// O(√n log n) で n! mod P を計算する関数
mint fast_fact(long long n) {
if (n >= MOD) return 0;
long long v = sqrt(n);
int d = 1;
vector<mint> V = {0, v};
int msb = 0;
while ((1LL << (msb + 1)) <= v) msb++;
// ダブリングにより次数 v の評価値を求める
for (int step = msb - 1; step >= 0; step--) {
vector<mint> V2 = shift(V, d + 1, d - 1);
vector<mint> V_full = V;
for (auto x : V2) V_full.push_back(x);
vector<mint> H = shift(V, mint(d) * mint(v).inv(), 2 * d);
vector<mint> next_V(2 * d + 1);
for (int i = 0; i <= 2 * d; i++) {
next_V[i] = V_full[i] * H[i];
}
d = 2 * d;
V = next_V;
// 奇数次数の場合は +1 補正
if ((v >> step) & 1) {
vector<mint> V_next(d + 2);
for (int k = 0; k <= d; k++) {
V_next[k] = V[k] * (mint(k) * v + d);
}
vector<mint> V_d1 = shift(V, d + 1, 0);
V_next[d + 1] = V_d1[0] * (mint(d + 1) * v + d);
d = d + 1;
V = V_next;
}
}
// まとめて階乗を計算
mint fact_n = 1;
for (int k = 1; k <= v; k++) {
fact_n *= V[k];
}
// 端数を掛ける (最大でも √n 回程度)
for (long long i = (long long)v * v + 1; i <= n; i++) {
fact_n *= i;
}
return fact_n;
}
int main() {
// 入出力の高速化
ios_base::sync_with_stdio(false);
cin.tie(NULL);
long long N, K;
if (!(cin >> N >> K)) return 0;
if (K < 0 || K > N) {
cout << 0 << "\n";
return 0;
}
// リュカの定理の準備
long long n0 = N % MOD, n1 = N / MOD;
long long k0 = K % MOD, k1 = K / MOD;
if (k1 > n1) {
cout << 0 << "\n";
return 0;
}
mint ans1 = 1; // 制約上 n1 <= 1 なので、k1 <= n1 ならば組合せは 1 通り
if (k0 > n0) {
cout << 0 << "\n";
return 0;
}
mint ans0 = 0;
if (k0 == 0 || k0 == n0) {
ans0 = 1;
} else if (min(k0, n0 - k0) <= 2000000) {
// 必要計算量が小さい場合は軽量なO(K)ループを使用 (計算負荷の削減)
long long r = min(k0, n0 - k0);
mint num = 1, den = 1;
for (long long i = 1; i <= r; i++) {
num *= (n0 - i + 1);
den *= i;
}
ans0 = num / den;
} else {
// N, Kが巨大な場合は NTTベースの高速階乗アルゴリズムを使用
initFact(100000);
mint fn = fast_fact(n0);
mint fk = fast_fact(k0);
mint fnk = fast_fact(n0 - k0);
ans0 = fn / (fk * fnk);
}
mint final_ans = ans1 * ans0;
cout << final_ans.val() << "\n";
return 0;
}
回転