結果

問題 No.1388 Less than K
コンテスト
ユーザー vjudge1
提出日時 2026-04-18 12:24:30
言語 C++23
(gcc 15.2.0 + boost 1.89.0)
コンパイル:
g++-15 -O2 -lm -std=c++23 -Wuninitialized -DONLINE_JUDGE -o a.out _filename_
実行:
./a.out
結果
WA  
実行時間 -
コード長 10,340 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 4,773 ms
コンパイル使用メモリ 251,916 KB
実行使用メモリ 7,976 KB
最終ジャッジ日時 2026-04-18 12:25:14
合計ジャッジ時間 16,398 ms
ジャッジサーバーID
(参考情報)
judge1_0 / judge2_1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2 WA * 1
other AC * 11 WA * 22 TLE * 1 -- * 40
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>

using namespace std;

// Prime modulus for every residue in this file (998244353 = 119 * 2^23 + 1).
const int MOD = 998244353;
// Generator used by ntt() to derive roots of unity (3 is a primitive root of MOD).
const int G = 3;

// --- Modular arithmetic helpers ---

// Returns base^exp mod MOD by binary exponentiation; a non-positive exp yields 1.
// A negative base is reduced with C++'s truncating `%`, so the result can then be
// negative (all callers in this file pass non-negative values).
int power(int base, int exp) {
    long long result = 1;
    long long factor = base % MOD;
    for (; exp > 0; exp >>= 1) {
        if (exp & 1) {
            result = result * factor % MOD;
        }
        factor = factor * factor % MOD;
    }
    return static_cast<int>(result);
}

// Modular inverse via Fermat's little theorem (n^(MOD-2)); valid because MOD is
// prime. Returns 0 when n is a multiple of MOD (no inverse exists).
int modInverse(int n) {
    const int fermat_exponent = MOD - 2;
    return power(n, fermat_exponent);
}

// Factorial / inverse-factorial tables for binomial coefficients modulo MOD.
struct Comb {
    vector<int> fac;   // fac[i]  = i! mod MOD
    vector<int> ifac;  // ifac[i] = (i!)^{-1} mod MOD

    // Precomputes fac and ifac for indices 0..n.
    void init(int n) {
        fac.assign(n + 1, 1);
        ifac.assign(n + 1, 1);
        for (int i = 1; i <= n; ++i) {
            fac[i] = static_cast<int>(1LL * fac[i - 1] * i % MOD);
        }
        // A single modular inversion, then walk downward using 1/(i-1)! = i * (1/i!).
        ifac[n] = modInverse(fac[n]);
        for (int i = n; i > 0; --i) {
            ifac[i - 1] = static_cast<int>(1LL * ifac[i] * i % MOD);
        }
    }

    // C(n, k) mod MOD; 0 when k lies outside [0, n].
    // NOTE(review): n must not exceed the size passed to init() — not checked here.
    int C(int n, int k) {
        if (k < 0 || k > n) {
            return 0;
        }
        const long long head = 1LL * fac[n] * ifac[k] % MOD;
        return static_cast<int>(head * ifac[n - k] % MOD);
    }
} comb;

// --- NTT (number-theoretic transform) state shared across calls ---
// Bit-reversal permutation for the most recently used transform length.
vector<int> rev_arr;
// Twiddle table grown lazily by ntt(): for a power-of-two half-length k,
// roots_arr[k + j] is the j-th power of a primitive (2k)-th root of unity.
// Index 0 is never read.
vector<int> roots_arr = {0, 1};

// Fills rev_arr with the bit-reversal permutation for a power-of-two length n.
// Cached by size: nothing is recomputed when rev_arr already has n entries.
// rev_arr[0] is never written and stays 0 (resize keeps it / value-initializes it).
void get_rev(int n) {
    if (static_cast<int>(rev_arr.size()) == n) {
        return;
    }
    rev_arr.resize(n);
    const int top_bit = n >> 1;
    for (int idx = 1; idx < n; ++idx) {
        // Reuse the reversal of idx/2 and move idx's lowest bit to the top position.
        const int moved_low_bit = (idx & 1) ? top_bit : 0;
        rev_arr[idx] = (rev_arr[idx >> 1] >> 1) | moved_low_bit;
    }
}

// In-place iterative NTT over MOD of length n (n must be a power of two; with
// MOD - 1 = 119 * 2^23, roots only exist for n <= 2^23).
// a is resized to n (zero padded or truncated) and replaced by its transform,
// in natural order. dir == -1 computes the inverse transform; any other dir value
// runs the forward transform.
// Side effects: refreshes the global rev_arr / roots_arr caches.
void ntt(vector<int>& a, int n, int dir) {
    a.resize(n, 0);
    get_rev(n);
    // Bit-reversal permutation so the butterflies below can work in place.
    for (int i = 1; i < n; i++) {
        if (i < rev_arr[i]) swap(a[i], a[rev_arr[i]]);
    }
    // Lazily extend the twiddle table to length n. Level k+1 (indices
    // [2^k, 2^(k+1))) is derived from level k: even slots copy the previous root,
    // odd slots multiply it by e = G^((MOD-1) / 2^(k+1)), a 2^(k+1)-th root of unity.
    if (roots_arr.size() < n) {
        int k = __builtin_ctz(roots_arr.size());
        roots_arr.resize(n);
        while ((1 << k) < n) {
            int e = power(G, (MOD - 1) >> (k + 1));
            for (int i = 1 << (k - 1); i < (1 << k); i++) {
                roots_arr[2 * i] = roots_arr[i];
                roots_arr[2 * i + 1] = 1LL * roots_arr[i] * e % MOD;
            }
            k++;
        }
    }
    // Cooley-Tukey butterflies; roots_arr[k + j] is the twiddle for offset j inside
    // a block of half-length k. Assuming inputs are reduced to [0, MOD), the
    // conditional add/subtract keeps every value in [0, MOD) without a `%`.
    for (int k = 1; k < n; k <<= 1) {
        for (int i = 0; i < n; i += 2 * k) {
            for (int j = 0; j < k; j++) {
                int u = a[i + j];
                int v = 1LL * a[i + j + k] * roots_arr[k + j] % MOD;
                a[i + j] = u + v >= MOD ? u + v - MOD : u + v;
                a[i + j + k] = u - v < 0 ? u - v + MOD : u - v;
            }
        }
    }
    if (dir == -1) {
        // Inverse via the identity INTT(a)[i] = NTT(a)[(n - i) mod n] / n:
        // reverse entries 1..n-1, then scale everything by 1/n.
        reverse(a.begin() + 1, a.end());
        int inv = modInverse(n);
        for (int i = 0; i < n; i++) a[i] = 1LL * a[i] * inv % MOD;
    }
}

// Polynomial product mod MOD via NTT (coefficients lowest degree first), in
// O(L log L) for L = |a| + |b|. Returns {} if either input is empty; otherwise
// trailing zero coefficients are stripped, keeping at least one entry.
vector<int> poly_mul(const vector<int>& a, const vector<int>& b) {
    if (a.empty() || b.empty()) {
        return {};
    }
    const int out_len = static_cast<int>(a.size() + b.size() - 1);
    int fft_len = 1;
    while (fft_len < out_len) {
        fft_len *= 2;
    }

    vector<int> lhs(a);
    vector<int> rhs(b);
    ntt(lhs, fft_len, 1);
    ntt(rhs, fft_len, 1);
    for (int idx = 0; idx < fft_len; ++idx) {
        lhs[idx] = static_cast<int>(1LL * lhs[idx] * rhs[idx] % MOD);
    }
    ntt(lhs, fft_len, -1);

    lhs.resize(out_len);
    while (lhs.size() > 1 && lhs.back() == 0) {
        lhs.pop_back();
    }
    return lhs;
}

// Power-series inverse: returns the first deg coefficients of b with
// a * b == 1 (mod x^deg). Requires a non-empty with a[0] != 0 mod MOD
// (modInverse(0) returns 0, which would silently yield garbage).
// Newton iteration b <- b * (2 - a * b) doubles the number of correct terms per round.
vector<int> poly_inv(const vector<int>& a, int deg) {
    vector<int> b(1, modInverse(a[0]));
    // At loop entry b has len/2 coefficients, correct mod x^(len/2); after the
    // round it has len coefficients, correct mod x^len.
    for (int len = 2; (len >> 1) < deg; len <<= 1) {
        // Only the first len coefficients of a matter modulo x^len.
        vector<int> ta(a.begin(), a.begin() + min((int)a.size(), len));
        // Transform length 2*len: deg(ta * tb * tb) <= (len-1) + 2*(len/2-1) = 2*len-3,
        // so the cyclic convolution does not wrap around.
        int n = len << 1;
        ta.resize(n, 0);
        vector<int> tb = b;
        tb.resize(n, 0);
        ntt(ta, n, 1); ntt(tb, n, 1);
        for (int i = 0; i < n; i++) {
            // Pointwise b * (2 - a*b); adding MOD keeps the subtraction non-negative.
            tb[i] = 1LL * tb[i] * (2LL + MOD - 1LL * ta[i] * tb[i] % MOD) % MOD;
        }
        ntt(tb, n, -1);
        tb.resize(len);
        b = tb;
    }
    b.resize(deg);
    return b;
}

// Polynomial division with remainder mod MOD: A = Q * B + R with deg R < deg B.
// A and B are taken by value and trimmed of trailing zeros; Q and R are overwritten.
// R has trailing zeros stripped (down to one entry) and is empty when B is a
// nonzero constant. Requires B non-empty with a nonzero leading coefficient.
// Method: reversing coefficients turns the quotient into the power-series product
// rev(A) * rev(B)^{-1} mod x^(n-m+1).
void poly_div_mod(vector<int> A, vector<int> B, vector<int>& Q, vector<int>& R) {
    while (A.size() > 1 && A.back() == 0) A.pop_back();
    while (B.size() > 1 && B.back() == 0) B.pop_back();
    int n = A.size(), m = B.size();
    // deg A < deg B: the quotient is zero and A itself is the remainder.
    if (n < m) { Q = {0}; R = A; return; }
    vector<int> revA = A, revB = B;
    reverse(revA.begin(), revA.end());
    reverse(revB.begin(), revB.end());
    // Only n-m+1 terms of rev(B)^{-1} are needed for a quotient of degree n-m.
    revB.resize(n - m + 1, 0);
    vector<int> invB = poly_inv(revB, n - m + 1);
    Q = poly_mul(revA, invB);
    Q.resize(n - m + 1, 0);
    reverse(Q.begin(), Q.end());
    // R = A - Q*B; only its low m-1 coefficients can be nonzero.
    vector<int> QB = poly_mul(Q, B);
    R.resize(m - 1, 0);
    for (int i = 0; i < m - 1; i++) {
        R[i] = A[i] - QB[i];
        if (R[i] < 0) R[i] += MOD;
    }
    while (R.size() > 1 && R.back() == 0) R.pop_back();
}

// --- O(N log N) ???? P(x+c) ---
// ??? O(N^2) ??????????? NTT ????
vector<int> shift_poly(const vector<int>& a, int c) {
    if (a.empty()) return {};
    int n = a.size();
    vector<int> A(n), B(n);
    int c_pow = 1;
    for (int i = 0; i < n; i++) {
        A[n - 1 - i] = 1LL * a[i] * comb.fac[i] % MOD;
        B[i] = 1LL * c_pow * comb.ifac[i] % MOD;
        c_pow = 1LL * c_pow * c % MOD;
    }
    vector<int> res = poly_mul(A, B);
    vector<int> ret(n);
    for (int i = 0; i < n; i++) {
        ret[i] = 1LL * res[n - 1 - i] * comb.ifac[i] % MOD;
    }
    while (ret.size() > 1 && ret.back() == 0) ret.pop_back();
    return ret;
}

// --- Multipoint evaluation (subproduct tree) ---
// Node polynomials of the subproduct tree, 1-based heap layout (children 2v, 2v+1);
// filled by build_tree() and read by eval_tree().
vector<vector<int>> eval_tree_nodes{};

// Builds the subproduct tree over X[l..r] into eval_tree_nodes: a leaf stores
// x - X[l] as {-X[l] mod MOD, 1}; an inner node stores the product of its children.
// The caller must size eval_tree_nodes to hold every node index reached.
void build_tree(int node, int l, int r, const vector<int>& X) {
    if (l != r) {
        const int mid = (l + r) >> 1;
        const int left_child = node << 1;
        const int right_child = left_child | 1;
        build_tree(left_child, l, mid, X);
        build_tree(right_child, mid + 1, r, X);
        eval_tree_nodes[node] =
            poly_mul(eval_tree_nodes[left_child], eval_tree_nodes[right_child]);
        return;
    }
    eval_tree_nodes[node] = {(MOD - X[l]) % MOD, 1};
}

// Evaluates polynomial P at the points X[l..r] of the subproduct tree rooted at
// node, storing P(X[i]) in res[i]. P is reduced modulo each node's polynomial on
// the way down, so a leaf holds P mod (x - X[i]) = P(X[i]).
void eval_tree(int node, int l, int r, vector<int> P, vector<int>& res) {
    vector<int> quotient;
    vector<int> remainder;
    poly_div_mod(P, eval_tree_nodes[node], quotient, remainder);
    if (l != r) {
        const int mid = (l + r) >> 1;
        eval_tree(node << 1, l, mid, remainder, res);
        eval_tree((node << 1) | 1, mid + 1, r, remainder, res);
        return;
    }
    res[l] = remainder.empty() ? 0 : remainder.front();
}

// --- Upper-triangular 2x2 polynomial matrix [[V, S], [0, U]] ---
// V and U are the diagonal entries, S the upper-right entry; composed by combine().
struct PolyMat {
    vector<int> V;
    vector<int> U;
    vector<int> S;
};

// Product left * right of two upper-triangular polynomial matrices:
//   [[lV, lS], [0, lU]] * [[rV, rS], [0, rU]] = [[lV*rV, lV*rS + lS*rU], [0, lU*rU]].
// Coefficients are mod MOD; trailing zeros of the S entry are stripped (one kept).
PolyMat combine(const PolyMat& left, const PolyMat& right) {
    PolyMat product;
    product.V = poly_mul(left.V, right.V);
    product.U = poly_mul(left.U, right.U);
    const vector<int> diag_times_corner = poly_mul(left.V, right.S);
    const vector<int> corner_times_diag = poly_mul(left.S, right.U);

    vector<int>& corner = product.S;
    corner.assign(max(diag_times_corner.size(), corner_times_diag.size()), 0);
    for (const vector<int>* part : {&diag_times_corner, &corner_times_diag}) {
        for (size_t idx = 0; idx < part->size(); ++idx) {
            corner[idx] = (corner[idx] + (*part)[idx]) % MOD;
        }
    }
    while (corner.size() > 1 && corner.back() == 0) {
        corner.pop_back();
    }
    return product;
}

// Expands prod_f (f.first + f.second * x) mod MOD into coefficients (lowest degree
// first). Quadratic in the number of factors.
vector<int> build_poly(const vector<pair<int, int>>& factors) {
    vector<int> poly{1};
    for (const auto& [constant_term, linear_term] : factors) {
        vector<int> grown(poly.size() + 1, 0);
        for (size_t k = 0; k < poly.size(); ++k) {
            const long long coef = poly[k];
            grown[k] = (grown[k] + coef * constant_term) % MOD;
            grown[k + 1] = (grown[k + 1] + coef * linear_term) % MOD;
        }
        poly = grown;
    }
    return poly;
}

// Problem parameters, assigned in main(): H_base = H - 1, W_base = W - 1,
// N = H_base + W_base, and P = 2 * (K / 2) + 2 (the index step used in the sums).
long long H_base;
long long W_base;
long long N;
long long P;
// Term-ratio numerator / denominator polynomials, consumed by B_base().
vector<int> U_poly;
vector<int> V_poly;

// Base matrix for index shift i: diagonal entries V_poly(x + i) and U_poly(x + i);
// the upper-right entry is a copy of the U entry.
PolyMat B_base(int i) {
    vector<int> shifted_v = shift_poly(V_poly, i);
    vector<int> shifted_u = shift_poly(U_poly, i);
    return PolyMat{shifted_v, shifted_u, shifted_u};
}

// Product of the base matrices for indices l..r, later indices on the left:
// B_base(r) * ... * B_base(l). Splitting at the midpoint keeps operands balanced.
PolyMat build_block(int l, int r) {
    if (l == r) {
        return B_base(l);
    }
    const int split = l + (r - l) / 2;
    const PolyMat earlier = build_block(l, split);
    const PolyMat later = build_block(split + 1, r);
    return combine(later, earlier);
}

// Returns sum over all integers k of C(N, H + k*P) * C(N, W + k*P) mod MOD, with
// H = H_base + offset and W = W_base + offset.
//
// A term is nonzero only when both lower indices lie in [0, N], i.e. when k*P is in
// [max(-H, -W), min(N - H, N - W)]. That leaves at most N / P + 1 terms, each O(1)
// with comb's precomputed factorials. No polynomial (baby-step/giant-step)
// acceleration is needed, and the direct sum avoids its pitfalls: an unreachable
// fallback branch, a double-counted tail term, and out-of-range factorial reads.
// Assumes P > 0 and comb initialized up to N.
long long solve_seq(long long offset) {
    const long long H = H_base + offset;
    const long long W = W_base + offset;
    const long long lo = max(-H, -W);        // smallest admissible k*P
    const long long hi = min(N - H, N - W);  // largest admissible k*P
    if (lo > hi) return 0;

    // Division rounding toward -inf / +inf for a positive divisor; plain `/`
    // truncates toward zero, which is wrong for negative numerators.
    auto floor_div = [](long long num, long long den) {
        long long q = num / den;
        if (num % den != 0 && num < 0) --q;
        return q;
    };
    auto ceil_div = [](long long num, long long den) {
        long long q = num / den;
        if (num % den != 0 && num > 0) ++q;
        return q;
    };

    const long long k_first = ceil_div(lo, P);
    const long long k_last = floor_div(hi, P);
    const int top = static_cast<int>(N);

    long long total = 0;
    for (long long k = k_first; k <= k_last; ++k) {
        const long long term = 1LL * comb.C(top, static_cast<int>(H + k * P)) *
                               comb.C(top, static_cast<int>(W + k * P)) % MOD;
        total = (total + term) % MOD;
    }
    return total;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    
    long long H_in, W_in, K_in;
    if (!(cin >> H_in >> W_in >> K_in)) return 0;
    
    H_base = H_in - 1;
    W_base = W_in - 1;
    N = H_base + W_base;
    long long d = K_in / 2;
    P = 2 * d + 2;
    
    comb.init(N);
    
    if (d == 0) {
        cout << comb.C(N, H_base) << '\n';
        return 0;
    }
    if (d >= min(H_base, W_base)) {
        long long all = comb.C(N, H_base);
        cout << all * all % MOD << '\n';
        return 0;
    }
    
    long long ans = (solve_seq(0) - solve_seq(d + 1) + MOD) % MOD;
    cout << ans << '\n';
    
    return 0;
}
0