結果

問題 No.1388 Less than K
コンテスト
ユーザー vjudge1
提出日時 2026-04-17 23:14:25
言語 C++23
(gcc 15.2.0 + boost 1.89.0)
コンパイル:
g++-15 -O2 -lm -std=c++23 -Wuninitialized -DONLINE_JUDGE -o a.out _filename_
実行:
./a.out
結果
AC  
実行時間 528 ms / 3,000 ms
コード長 9,956 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 3,980 ms
コンパイル使用メモリ 233,112 KB
実行使用メモリ 8,796 KB
最終ジャッジ日時 2026-04-17 23:14:42
合計ジャッジ時間 10,219 ms
ジャッジサーバーID
(参考情報)
judge1_1 / judge2_0
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 74
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

#pragma GCC optimize("O3")
#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>

using namespace std;

// Prime modulus used throughout: 998244353 = 119 * 2^23 + 1, so NTT lengths
// up to 2^23 have the required roots of unity.
const int MOD = 998244353;
// Primitive root modulo MOD; ntt() derives its roots of unity from it.
const int G = 3;

// --- Modular arithmetic helpers ---

// Returns base^exp mod MOD by square-and-multiply; returns 1 when exp <= 0.
int power(int base, int exp) {
    long long acc = 1;
    long long sq = base % MOD;
    for (int e = exp; e > 0; e >>= 1) {
        if (e & 1) acc = acc * sq % MOD;
        sq = sq * sq % MOD;
    }
    return static_cast<int>(acc);
}

// Modular inverse by Fermat's little theorem (MOD is prime): n^(MOD-2) mod MOD.
// Returns 0 when n is a multiple of MOD (no inverse exists).
int modInverse(int n) {
    const int fermat_exponent = MOD - 2;
    return power(n, fermat_exponent);
}

// Factorial tables for binomial coefficients modulo MOD.
struct Comb {
    vector<int> fac, ifac;  // fac[i] = i! mod MOD, ifac[i] = (i!)^(-1) mod MOD

    // Fills both tables for indices 0..n (n >= 0).
    void init(int n) {
        fac.assign(n + 1, 1);
        ifac.assign(n + 1, 1);
        for (int i = 1; i <= n; ++i) {
            fac[i] = static_cast<int>(1LL * fac[i - 1] * i % MOD);
        }
        // Invert n! once, then walk down using (i-1)!^(-1) = i!^(-1) * i.
        ifac[n] = modInverse(fac[n]);
        for (int i = n; i >= 1; --i) {
            ifac[i - 1] = static_cast<int>(1LL * ifac[i] * i % MOD);
        }
    }

    // Binomial coefficient C(n, k) mod MOD; 0 when k is outside [0, n].
    // Requires n to be at most the size passed to init().
    int C(int n, int k) {
        if (k < 0 || k > n) return 0;
        const long long top = fac[n];
        return static_cast<int>(top * ifac[k] % MOD * ifac[n - k] % MOD);
    }
} comb;

// --- Number-theoretic transform modulo MOD ---
// In-place iterative radix-2 NTT. a.size() is expected to be a power of two
// (the callers in this file pad to one). With invert == true this computes the
// inverse transform, including the final division by a.size().
void ntt(vector<int>& a, bool invert) {
    const int size = a.size();

    // Reorder the elements into bit-reversed index order.
    int rev = 0;
    for (int i = 1; i < size; ++i) {
        int mask = size >> 1;
        while (rev & mask) {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if (i < rev) swap(a[i], a[rev]);
    }

    // Butterfly passes over blocks of length len = 2, 4, ..., size.
    for (int len = 2; len <= size; len <<= 1) {
        const int half = len >> 1;
        // Primitive len-th root of unity (its inverse for the inverse transform).
        int root = power(G, (MOD - 1) / len);
        if (invert) root = modInverse(root);
        for (int base = 0; base < size; base += len) {
            int twiddle = 1;
            for (int k = 0; k < half; ++k) {
                const int even = a[base + k];
                const int odd = (1LL * a[base + k + half] * twiddle) % MOD;
                int sum = even + odd;
                if (sum >= MOD) sum -= MOD;
                int diff = even - odd;
                if (diff < 0) diff += MOD;
                a[base + k] = sum;
                a[base + k + half] = diff;
                twiddle = (1LL * twiddle * root) % MOD;
            }
        }
    }

    if (!invert) return;
    const int size_inv = modInverse(size);
    for (int& value : a) value = (1LL * value * size_inv) % MOD;
}

// Multiplies polynomials a and b (coefficients mod MOD, lowest degree first)
// with an NTT-based cyclic convolution. Returns {} if either input is empty;
// otherwise the product with trailing zero coefficients trimmed (at least one
// coefficient is always kept).
vector<int> poly_mul(vector<int> const& a, vector<int> const& b) {
    if (a.empty() || b.empty()) return {};
    // The linear product has exactly |a| + |b| - 1 coefficients, and a cyclic
    // convolution of any length n >= that count reproduces it without
    // wrap-around. Sizing to this count (not |a| + |b|) avoids doubling the
    // transform when the count is already a power of two.
    const size_t sz = a.size() + b.size() - 1;
    size_t n = 1;
    while (n < sz) n <<= 1;
    vector<int> fa(a), fb(b);
    fa.resize(n, 0);
    fb.resize(n, 0);
    ntt(fa, false);
    ntt(fb, false);
    for (size_t i = 0; i < n; i++) fa[i] = (1LL * fa[i] * fb[i]) % MOD;
    ntt(fa, true);
    fa.resize(sz);
    while (fa.size() > 1 && fa.back() == 0) fa.pop_back();
    return fa;
}

// Returns the first `deg` coefficients of 1 / a(x) as a power series mod MOD,
// using Newton iteration r <- r * (2 - a * r), which doubles the precision at
// each recursion level. Coefficients of `a` beyond index deg - 1 are ignored.
// Preconditions for deg >= 1: a is non-empty and a[0] != 0 (mod MOD). The
// result then has exactly `deg` entries. deg <= 0 yields an empty vector,
// where the old code recursed forever because (0 + 1) / 2 == 0.
// `a` is taken by const reference because the recursion only reads it, so the
// old by-value parameter copied it at every level for nothing.
vector<int> poly_inv(const vector<int>& a, int deg) {
    if (deg <= 0) return {};
    if (deg == 1) return {modInverse(a[0])};
    vector<int> res = poly_inv(a, (deg + 1) / 2);
    // a (at most deg terms) times res^2 has degree < 2 * deg, so a cyclic
    // transform of length >= 2 * deg yields the needed coefficients exactly.
    int n = 1;
    while (n < deg * 2) n <<= 1;
    vector<int> fa(a.begin(), a.begin() + min((int)a.size(), deg));
    fa.resize(n, 0);
    vector<int> fres = res;
    fres.resize(n, 0);
    ntt(fa, false); ntt(fres, false);
    for (int i = 0; i < n; i++) {
        // Pointwise r * (2 - a * r).
        fres[i] = (2LL - 1LL * fa[i] * fres[i] % MOD + MOD) % MOD * fres[i] % MOD;
    }
    ntt(fres, true);
    fres.resize(deg);
    return fres;
}

// Polynomial division mod MOD: finds Q, R with A = Q * B + R and deg R < deg B.
// Coefficients are lowest degree first, and trailing zeros of A and B are
// ignored.
// - If deg A < deg B: Q = {0} and R = A (trimmed).
// - Otherwise R keeps only its first m - 1 coefficients (m = |B|), trimmed of
//   trailing zeros, so R is empty when B is a constant.
// Q comes from the reversal trick: rev(Q) = rev(A) * rev(B)^(-1) mod x^(n-m+1).
void poly_div_mod(vector<int> A, vector<int> B, vector<int>& Q, vector<int>& R) {
    auto trim = [](vector<int>& p) {
        while (p.size() > 1 && p.back() == 0) p.pop_back();
    };
    trim(A);
    trim(B);
    const int n = A.size();
    const int m = B.size();
    if (n < m) {
        Q = {0};
        R = A;
        return;
    }

    const int q_len = n - m + 1;
    vector<int> revA(A.rbegin(), A.rend());
    vector<int> revB(B.rbegin(), B.rend());
    revB.resize(q_len, 0);
    Q = poly_mul(revA, poly_inv(revB, q_len));
    Q.resize(q_len, 0);
    reverse(Q.begin(), Q.end());

    // The low m - 1 coefficients of A - Q * B form the remainder.
    const vector<int> QB = poly_mul(Q, B);
    R.resize(m - 1, 0);
    for (int i = 0; i + 1 < m; ++i) {
        int r = (A[i] - QB[i]) % MOD;
        if (r < 0) r += MOD;
        R[i] = r;
    }
    trim(R);
}

// --- Multipoint evaluation via a subproduct tree ---

// Heap-indexed subproduct tree: the root is node 1, and the children of node v
// are 2v and 2v + 1. Callers size it before build_tree (solve_seq uses
// 4 * number of points).
vector<vector<int>> eval_tree_nodes;

// Stores the product of (x - X[i]) for i in [l, r] in eval_tree_nodes[node],
// recursing into both halves first. Expects l <= r.
void build_tree(int node, int l, int r, const vector<int>& X) {
    if (l == r) {
        // Linear factor x - X[l], lowest degree first, reduced mod MOD.
        eval_tree_nodes[node] = {(MOD - X[l]) % MOD, 1};
        return;
    }
    const int mid = (l + r) / 2;
    const int lhs = 2 * node;
    const int rhs = lhs + 1;
    build_tree(lhs, l, mid, X);
    build_tree(rhs, mid + 1, r, X);
    eval_tree_nodes[node] = poly_mul(eval_tree_nodes[lhs], eval_tree_nodes[rhs]);
}

// Evaluates polynomial P at the points of the subproduct tree built by
// build_tree. On return, res[i] = P(X[i]) mod MOD for every i in [l, r].
// Each node reduces P modulo its subtree product before recursing, so the
// children receive lower-degree polynomials. P is taken by const reference:
// poly_div_mod already makes its own copy, so the old by-value parameter made
// an extra copy of the polynomial at every tree node.
void eval_tree(int node, int l, int r, const vector<int>& P, vector<int>& res) {
    vector<int> Q, R;
    poly_div_mod(P, eval_tree_nodes[node], Q, R);
    if (l == r) {
        // The remainder modulo (x - X[l]) is the constant P(X[l]); empty R means 0.
        res[l] = R.empty() ? 0 : R[0];
        return;
    }
    const int mid = (l + r) / 2;
    eval_tree(2 * node, l, mid, R, res);
    eval_tree(2 * node + 1, mid + 1, r, R, res);
}

// --- Holonomic (P-recursive) sequence summation via baby-step / giant-step ---

// Upper-triangular 2x2 matrix of polynomials [[V, S], [0, U]], coefficients
// lowest degree first; combine() implements the matrix product.
struct PolyMat {
    vector<int> V;
    vector<int> U;
    vector<int> S;
};

// Matrix product left * right of two PolyMat values, each representing the
// upper-triangular polynomial matrix [[V, S], [0, U]]:
//   [[lV, lS], [0, lU]] * [[rV, rS], [0, rU]]
//     = [[lV*rV, lV*rS + lS*rU], [0, lU*rU]].
// The S entry is trimmed of trailing zeros (at least one coefficient kept).
PolyMat combine(const PolyMat& left, const PolyMat& right) {
    const vector<int> lv_rs = poly_mul(left.V, right.S);
    const vector<int> ls_ru = poly_mul(left.S, right.U);

    PolyMat product;
    product.V = poly_mul(left.V, right.V);
    product.U = poly_mul(left.U, right.U);

    const size_t width = max(lv_rs.size(), ls_ru.size());
    product.S.assign(width, 0);
    for (size_t k = 0; k < width; ++k) {
        long long acc = 0;
        if (k < lv_rs.size()) acc += lv_rs[k];
        if (k < ls_ru.size()) acc += ls_ru[k];
        product.S[k] = acc % MOD;
    }
    while (product.S.size() > 1 && product.S.back() == 0) product.S.pop_back();
    return product;
}

// Expands a product of linear factors: each pair {c0, c1} stands for
// (c0 + c1 * x). Returns the coefficients mod MOD, lowest degree first
// ({1} for an empty list), by multiplying in one factor at a time.
vector<int> build_poly(const vector<pair<int, int>>& factors) {
    vector<int> coeffs{1};
    for (const auto& [c0, c1] : factors) {
        vector<int> grown(coeffs.size() + 1, 0);
        for (size_t k = 0; k < coeffs.size(); ++k) {
            grown[k] = (grown[k] + 1LL * coeffs[k] * c0) % MOD;
            grown[k + 1] = (grown[k + 1] + 1LL * coeffs[k] * c1) % MOD;
        }
        coeffs.swap(grown);
    }
    return coeffs;
}

// Taylor shift: returns the coefficients of A(x + c) mod MOD, lowest degree
// first, with trailing zeros trimmed (at least one coefficient is kept).
// Expands A(x + c) = sum_i A[i] * sum_{j<=i} C(i, j) * x^j * c^(i-j) directly,
// which costs O(A.size()^2). comb must be initialised up to A.size() - 1.
vector<int> shift_poly(const vector<int>& A, int c) {
    if (A.empty()) return {};
    const int len = A.size();

    // powers[k] = c^k mod MOD
    vector<int> powers(len, 1);
    for (int k = 1; k < len; ++k) powers[k] = (1LL * powers[k - 1] * c) % MOD;

    vector<int> shifted(len, 0);
    for (int i = 0; i < len; ++i) {
        const int coef = A[i];
        for (int j = 0; j <= i; ++j) {
            long long contrib = 1LL * coef * comb.C(i, j) % MOD;
            contrib = contrib * powers[i - j] % MOD;
            shifted[j] = (shifted[j] + contrib) % MOD;
        }
    }
    while (shifted.size() > 1 && shifted.back() == 0) shifted.pop_back();
    return shifted;
}

// Problem parameters assigned in main() and read by solve_seq().
long long H_base;
long long W_base;
long long N;
long long P;
// Numerator / denominator polynomials of the current term ratio; rebuilt by
// solve_seq() and read by B_base().
vector<int> U_poly;
vector<int> V_poly;

// Single-step matrix of the term-ratio recurrence, shifted by i:
// V = V_poly(x + i), U = U_poly(x + i) and S = U_poly(x + i).
PolyMat B_base(int i) {
    PolyMat step;
    step.U = shift_poly(U_poly, i);
    step.V = shift_poly(V_poly, i);
    // The partial-sum entry equals the numerator: ([[V, U], [0, U]] / V) applied
    // to (sum, term) gives (sum + term * U / V, term * U / V).
    step.S = step.U;
    return step;
}

// Returns the product of the single-step matrices for steps l..r, ordered so
// that step l is applied first: B_base(r) * ... * B_base(l). Expects l <= r.
PolyMat build_block(int l, int r) {
    if (l != r) {
        const int mid = l + (r - l) / 2;
        const PolyMat first_half = build_block(l, mid);
        const PolyMat second_half = build_block(mid + 1, r);
        // combine(a, b) computes a * b, so the later steps go on the left.
        return combine(second_half, first_half);
    }
    return B_base(l);
}

// Returns ceil(a / b) for b > 0. C++ division truncates toward zero, which
// already rounds up when a is negative.
long long ceil_div(long long a, long long b) {
    if (a < 0) return a / b;
    return (a + b - 1) / b;
}

// Returns floor(a / b) for b > 0. Truncation is already correct for a >= 0;
// for negative a, shifting the numerator by b - 1 makes truncation round down.
long long floor_div(long long a, long long b) {
    if (a < 0) return (a - b + 1) / b;
    return a / b;
}

// Sums t_k = C(N, H + k*P) * C(N, W + k*P) mod MOD over every integer k for
// which both H + k*P and W + k*P lie in [0, N], with H = H_base + offset and
// W = W_base + offset.
// Reads the globals H_base, W_base, N, P and the binomial table `comb`, and
// overwrites the globals U_poly, V_poly and eval_tree_nodes.
// Assumes P > 0: ceil_div / floor_div round correctly only for a positive
// divisor (main() sets P = 2 * d + 2).
long long solve_seq(long long offset) {
    long long H = H_base + offset;
    long long W = W_base + offset;
    // Admissible shifts k*P satisfy max(-H, -W) <= k*P <= min(N - H, N - W).
    long long min_kP = max(-H, -W);
    long long max_kP = min(N - H, N - W);
    
    if (min_kP > max_kP) return 0;
    long long L = ceil_div(min_kP, P);
    long long R = floor_div(max_kP, P);
    if (L > R) return 0;
    
    // M terms (k = L..R). H and W are re-based so that index i means k = L + i;
    // from here on the terms are t_0 .. t_{M-1} with t_i = C(N, H+i*P) * C(N, W+i*P).
    long long M = R - L + 1;
    H = H + L * P;
    W = W + L * P;
    
    // Few terms (M <= 1000): add them up directly in O(M).
    if (M <= 1000) {
        long long sum = 0;
        for (int i = 0; i < M; i++) {
            sum = (sum + 1LL * comb.C(N, H + i * P) * comb.C(N, W + i * P)) % MOD;
        }
        return sum;
    }
    
    // --- Many terms: baby-step / giant-step over the term-ratio recurrence ---
    // NOTE(review): the original header advertised O(\sqrt{N} \log^2 N). The block
    // polynomials have degree about 2 * P * S, so the real cost also depends on P;
    // that bound is not re-derived here.
    // Block length S and number of giant steps K (K * S <= M).
    int S = max(1, (int)sqrt(M));
    int K = M / S;
    
    // Linear factors {constant, x-coefficient}, reduced into [0, MOD), chosen so
    // that t_x = t_{x-1} * U_poly(x) / V_poly(x) for 1 <= x <= M:
    //   U: (N - H + j) - P*x and (N - W + j) - P*x, j = 1..P (numerators)
    //   V: (H - P + j) + P*x and (W - P + j) + P*x, j = 1..P (denominators)
    vector<pair<int, int>> U_factors, V_factors;
    for (int j = 1; j <= P; j++) {
        U_factors.push_back({((N - H + j) % MOD + MOD) % MOD, (MOD - P % MOD) % MOD});
        U_factors.push_back({((N - W + j) % MOD + MOD) % MOD, (MOD - P % MOD) % MOD});
        V_factors.push_back({((H - P + j) % MOD + MOD) % MOD, P % MOD});
        V_factors.push_back({((W - P + j) % MOD + MOD) % MOD, P % MOD});
    }
    
    U_poly = build_poly(U_factors);
    V_poly = build_poly(V_factors);
    // B(x): product of the single-step matrices for steps x, x+1, ..., x+S-1.
    PolyMat B = build_block(0, S - 1);
    
    // Giant-step points x_j = j*S + 1 for j = 0..K-1.
    vector<int> X_pts(K);
    for (int j = 0; j < K; j++) X_pts[j] = (1LL * j * S + 1) % MOD;
    
    // Evaluate the three entries of B at all giant-step points.
    eval_tree_nodes.assign(4 * K, vector<int>());
    build_tree(1, 0, K - 1, X_pts);
    vector<int> eval_V(K), eval_U(K), eval_S(K);
    eval_tree(1, 0, K - 1, B.V, eval_V);
    eval_tree(1, 0, K - 1, B.U, eval_U);
    eval_tree(1, 0, K - 1, B.S, eval_S);
    
    // State (partial sum, current term), both starting at t_0.
    long long true_S_val = 1LL * comb.C(N, H) * comb.C(N, W) % MOD;
    long long true_c_val = true_S_val;
    
    // Giant steps: step j applies B(j*S + 1) / V, i.e.
    //   (sum, term) <- ((V*sum + S*term) / V, U*term / V),
    // which adds t_{jS+1} .. t_{jS+S} to the sum.
    // If K*S == M, the last step also covers t_M. That term is 0 because U_poly(M)
    // contains a zero factor (H + M*P or W + M*P exceeds N).
    // NOTE(review): relies on every V value being nonzero mod MOD (modInverse(0)
    // returns 0). Its factors are integers in [1, N + P], so this holds while N + P < MOD.
    for (int j = 0; j < K; j++) {
        int V_val = eval_V[j], U_val = eval_U[j], S_val = eval_S[j];
        long long next_S = (1LL * V_val * true_S_val + 1LL * S_val * true_c_val) % MOD;
        long long next_c = (1LL * U_val * true_c_val) % MOD;
        long long inv_V = modInverse(V_val);
        true_S_val = next_S * inv_V % MOD;
        true_c_val = next_c * inv_V % MOD;
    }
    
    // Baby steps: add the remaining terms t_{K*S+1} .. t_{M-1} directly.
    for (int i = K * S + 1; i < M; i++) {
        long long term = 1LL * comb.C(N, H + i * P) * comb.C(N, W + i * P) % MOD;
        true_S_val = (true_S_val + term) % MOD;
    }
    
    return true_S_val;
}

// Reads H, W, K from stdin and prints the answer modulo MOD.
// H_base = H - 1 and W_base = W - 1; comb.C(N, H_base) with N = H_base + W_base
// counts the arrangements of H_base and W_base unit steps.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    long long H_in, W_in, K_in;
    if (!(cin >> H_in >> W_in >> K_in)) return 0;

    H_base = H_in - 1;
    W_base = W_in - 1;
    N = H_base + W_base;
    const long long d = K_in / 2;
    P = 2 * d + 2;  // period of the alternating sums evaluated by solve_seq

    comb.init(N);

    long long ans;
    if (d == 0) {
        ans = comb.C(N, H_base);
    } else if (d >= min(H_base, W_base)) {
        // The constraint never binds: every pair of step sequences counts.
        const long long all = comb.C(N, H_base);
        ans = all * all % MOD;
    } else {
        ans = (solve_seq(0) - solve_seq(d + 1) + MOD) % MOD;
    }
    cout << ans << '\n';

    return 0;
}
0