結果

問題 No.1762 🐙🐄🌲
ユーザー vjudge1
提出日時 2025-10-03 02:26:39
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 8,890 bytes
コンパイル時間 1,792 ms
コンパイル使用メモリ 125,912 KB
実行使用メモリ 20,908 KB
最終ジャッジ日時 2025-10-03 02:26:53
合計ジャッジ時間 6,447 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1 WA * 2
other AC * 20 WA * 27
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>

using namespace std;

// Prime modulus 998244353 = 119 * 2^23 + 1: MOD - 1 is divisible by 2^23, so
// power-of-two NTT lengths up to 2^23 have the required roots of unity.
constexpr int MOD = 998244353;
// Primitive root modulo MOD; G^((MOD-1)/len) is a primitive len-th root of unity.
constexpr int G = 3;

// Modular exponentiation: returns base^exp mod MOD using binary exponentiation.
// Any exp <= 0 yields 1 because the loop body never runs.
// The base is first reduced into [0, MOD). C++ '%' keeps the sign of the
// dividend, so without the correction a negative base gave a negative result.
long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    if (base < 0) base += MOD;
    while (exp > 0) {
        if (exp & 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp >>= 1;
    }
    return res;
}

// Modular inverse via Fermat's little theorem: since MOD is prime,
// n^(MOD-2) ≡ n^(-1) (mod MOD) for n not divisible by MOD.
// For n ≡ 0 there is no inverse; power() then returns 0.
long long modInverse(long long n) {
    const long long fermat_exponent = MOD - 2;
    return power(n, fermat_exponent);
}

// Factorial table: fact[i] = i! mod MOD for 0 <= i <= n (see precompute_factorials).
vector<long long> fact;
// Inverse factorial table: invFact[i] = (i!)^(-1) mod MOD for 0 <= i <= n.
vector<long long> invFact;

// Fills fact[0..n] and invFact[0..n].
// Only one modular exponentiation is needed: invFact[n] is inverted directly,
// and every smaller entry follows from (k-1)!^(-1) = k!^(-1) * k.
// Requires n >= 0 and n < MOD so that n! is invertible.
void precompute_factorials(int n) {
    fact.assign(n + 1, 0);
    invFact.assign(n + 1, 0);

    long long running = 1;
    fact[0] = running;
    for (int k = 1; k <= n; ++k) {
        running = running * k % MOD;
        fact[k] = running;
    }

    invFact[n] = modInverse(fact[n]);
    for (int k = n; k > 0; --k) {
        invFact[k - 1] = invFact[k] * k % MOD;
    }
}

// Binomial coefficient C(n, r) mod MOD from the precomputed factorial tables.
// Returns 0 when r lies outside [0, n]; otherwise fact/invFact must cover index n.
long long nCr(int n, int r) {
    const bool out_of_range = r < 0 || n < r;
    if (out_of_range) return 0;
    const long long n_fact_over_r_fact = fact[n] * invFact[r] % MOD;
    return n_fact_over_r_fact * invFact[n - r] % MOD;
}

// --- NTT (Number Theoretic Transform) Implementation ---

// Bit-reversal permutation of a[0..n): swaps a[i] with a[rev(i)], where rev
// reverses the log2(n) index bits. n is expected to be a power of two.
// `rev` is advanced incrementally: adding 1 to a bit-reversed counter means
// clearing its high set bits from the top down, then setting the first clear one.
void bit_reverse(vector<long long>& a, int n) {
    for (int i = 1, rev = 0; i < n; ++i) {
        int mask = n >> 1;
        for (; rev & mask; mask >>= 1) {
            rev ^= mask;
        }
        rev ^= mask;
        // Swap each pair only once: when i is the smaller index.
        if (i < rev) swap(a[i], a[rev]);
    }
}

// In-place iterative number-theoretic transform over Z/MOD (radix-2,
// decimation in time).
// a.size() must be a power of two dividing 2^23 (MOD - 1 = 119 * 2^23).
// invert == false: forward transform; a[k] becomes sum_j a[j] * w^(j*k) with
//                  w = G^((MOD-1)/n).
// invert == true:  inverse transform, including the final division by n.
void ntt(vector<long long>& a, bool invert) {
    const int n = static_cast<int>(a.size());
    // Reorder inputs so the butterflies below can work on contiguous halves.
    bit_reverse(a, n);

    for (int len = 2; len <= n; len <<= 1) {
        const int half = len / 2;
        // Primitive len-th root of unity (its inverse for the inverse transform).
        long long root = power(G, (MOD - 1) / len);
        if (invert) root = modInverse(root);

        for (int start = 0; start < n; start += len) {
            long long w = 1;
            for (int k = 0; k < half; ++k) {
                long long& lo = a[start + k];
                long long& hi = a[start + k + half];
                const long long even = lo;
                const long long odd = hi * w % MOD;
                lo = (even + odd) % MOD;
                hi = (even - odd + MOD) % MOD;
                w = w * root % MOD;
            }
        }
    }

    if (invert) {
        const long long inv_n = modInverse(n);
        for (long long& coeff : a) coeff = coeff * inv_n % MOD;
    }
}

// Polynomial multiplication mod MOD via NTT.
// Returns the full product, of size a.size() + b.size() - 1.
// If either factor is empty (the zero polynomial), returns an empty vector.
// This guard prevents the unsigned size computation from underflowing.
vector<long long> poly_mul(vector<long long> a, vector<long long> b) {
    if (a.empty() || b.empty()) return {};

    const int deg = static_cast<int>(a.size() + b.size() - 1);
    // Transform length >= deg so the cyclic convolution equals the linear one.
    int n = 1;
    while (n < deg) n <<= 1;
    a.resize(n, 0);
    b.resize(n, 0);

    ntt(a, false);
    ntt(b, false);
    // Pointwise product in the evaluation domain; reuse `a` as the buffer.
    for (int i = 0; i < n; i++) a[i] = a[i] * b[i] % MOD;
    ntt(a, true);

    a.resize(deg);
    return a;
}

// Power-series inverse: returns B (B.size() == n) with A * B ≡ 1 (mod x^n).
// Requires a[0] ≢ 0 (mod MOD); coefficients of `a` beyond index n-1 are ignored
// and a missing a[0] is treated as 0.
// Returns an empty vector for n <= 0, which would otherwise recurse forever.
vector<long long> poly_inv(const vector<long long>& a, int n) {
    if (n <= 0) return {};
    if (n == 1) return {modInverse(a.empty() ? 0 : a[0])};

    // Newton step: from B0 = A^(-1) mod x^ceil(n/2),
    //   B = B0 * (2 - A * B0) mod x^n.
    vector<long long> b0 = poly_inv(a, (n + 1) / 2);

    // deg(A * B0 * B0) <= (n-1) + 2*(ceil(n/2)-1) <= 2n-2, so length >= 2n
    // avoids cyclic wrap-around.
    int m = 1;
    while (m < 2 * n) m <<= 1;

    // A mod x^n
    vector<long long> A(a.begin(), a.begin() + min(static_cast<int>(a.size()), n));
    A.resize(m, 0);
    b0.resize(m, 0);

    ntt(A, false);
    ntt(b0, false);

    // Evaluate B0 * (2 - A * B0) pointwise; the constant 2 transforms to 2 everywhere.
    vector<long long> B(m);
    for (int i = 0; i < m; i++) {
        const long long ab = A[i] * b0[i] % MOD;
        const long long factor = (2 - ab + MOD) % MOD;
        B[i] = b0[i] * factor % MOD;
    }

    ntt(B, true);
    B.resize(n);
    return B;
}

// Formal derivative: d/dx sum a_i x^i = sum i * a_i x^(i-1), coefficients mod MOD.
// Returns an empty vector for empty or constant input.
vector<long long> poly_der(const vector<long long>& a) {
    if (a.size() <= 1) return {};
    vector<long long> out;
    out.reserve(a.size() - 1);
    for (size_t i = 1; i < a.size(); ++i) {
        out.push_back(a[i] * static_cast<long long>(i) % MOD);
    }
    return out;
}

// Formal integral with zero constant term.
// Returns res of size a.size() + 1 with res[0] = 0 and res[i+1] = a[i] / (i+1) mod MOD.
// An empty input yields {0}.
// The inverses of 1..a.size() come from the linear recurrence
//   inv[i] = -(MOD / i) * inv[MOD % i]  (mod MOD)
// instead of one modular exponentiation per coefficient (O(n) vs O(n log MOD)).
// This requires a.size() < MOD.
vector<long long> poly_int(const vector<long long>& a) {
    if (a.empty()) return {0};

    const int len = static_cast<int>(a.size());
    vector<long long> inv(len + 1, 1);
    for (int i = 2; i <= len; i++) {
        // MOD % i < i, so inv[MOD % i] is already known.
        inv[i] = (MOD - MOD / i) * inv[MOD % i] % MOD;
    }

    vector<long long> res(len + 1, 0);
    for (int i = 0; i < len; i++) {
        res[i + 1] = a[i] * inv[i + 1] % MOD;
    }
    return res;
}

// Power-series logarithm: returns ln(A) mod x^n (n coefficients, constant term 0),
// computed as the integral of A' * A^(-1).
// Requires a[0] == 1. Returns an empty vector when that precondition fails,
// when `a` is empty, or when n <= 0.
vector<long long> poly_ln(const vector<long long>& a, int n) {
    if (n <= 0 || a.empty() || a[0] != 1) return {};

    const vector<long long> a_inv = poly_inv(a, n);  // A^(-1) mod x^n
    const vector<long long> a_der = poly_der(a);     // A'

    // Length >= n + a_der.size() exceeds the product's degree, so no wrap-around.
    int size = 1;
    while (size < n + static_cast<int>(a_der.size())) size <<= 1;

    vector<long long> inv_ntt = a_inv;
    inv_ntt.resize(size, 0);
    vector<long long> der_ntt = a_der;
    der_ntt.resize(size, 0);

    ntt(inv_ntt, false);
    ntt(der_ntt, false);
    for (int i = 0; i < size; i++) der_ntt[i] = der_ntt[i] * inv_ntt[i] % MOD;
    ntt(der_ntt, true);

    // (ln A)' mod x^(n-1) is all that is needed; integrating gives ln A mod x^n.
    der_ntt.resize(n - 1);
    return poly_int(der_ntt);
}

// ????? (Polynomial Exponentiation)
// ?? exp(A) mod x^n. ?? A[0] = 0
vector<long long> poly_exp(const vector<long long>& a, int n) {
    if (n == 1) return {1};
    
    // ??? H_0 = exp(A) mod x^(n/2)
    vector<long long> h0 = poly_exp(a, (n + 1) / 2);
    
    // L = ln(H_0) mod x^n
    vector<long long> L = poly_ln(h0, n);
    
    // R = A - L mod x^n
    vector<long long> R(n);
    for (int i = 0; i < n; i++) {
        long long a_i = i < a.size() ? a[i] : 0;
        long long l_i = i < L.size() ? L[i] : 0;
        R[i] = (a_i - l_i + MOD) % MOD;
    }
    R[0] = (R[0] + 1) % MOD; // H = H_0 * (1 + A - ln(H_0)) mod x^n

    // ????? H = H_0 * R mod x^n
    int size = 1;
    while (size < 2 * n) size <<= 1;
    
    vector<long long> h0_ntt = h0;
    h0_ntt.resize(size, 0);
    vector<long long> R_ntt = R;
    R_ntt.resize(size, 0);
    
    ntt(h0_ntt, false);
    ntt(R_ntt, false);

    vector<long long> res(size);
    for (int i = 0; i < size; i++) {
        res[i] = (h0_ntt[i] * R_ntt[i]) % MOD;
    }
    
    ntt(res, true);
    res.resize(n);
    return res;
}

// Power-series power: returns G(x)^Q mod x^n (n coefficients) as exp(Q * ln G).
// Requires G[0] == 1 and n < MOD. If G[0] != 1, poly_ln yields nothing and the
// result degenerates to 1 + 0x + ...
// Q only enters as a scalar multiplier of ln G, so it is reduced mod MOD first.
// That is valid for n <= MOD, because G^MOD ≡ G(x^MOD) ≡ 1 (mod x^n).
// The reduction also avoids overflow in ln_G[i] * Q for large |Q|, and gives a
// negative Q its proper meaning G^(-|Q|).
// Returns an empty vector for n <= 0.
vector<long long> poly_q_pow(const vector<long long>& G, long long Q, int n) {
    if (n <= 0) return {};
    const long long q = ((Q % MOD) + MOD) % MOD;

    // ln(G) mod x^n; its constant term is 0, so exp() below is well-defined.
    const vector<long long> ln_G = poly_ln(G, n);

    vector<long long> A(n, 0);
    for (int i = 0; i < n && i < static_cast<int>(ln_G.size()); i++) {
        A[i] = ln_G[i] * q % MOD;
    }

    return poly_exp(A, n);
}


void solve() {
    int N;
    long long P;
    cin >> N >> P;

    // ?? N ? 5e5??????
    precompute_factorials(N);

    // --- 1. ?????? ---
    // N ???? N ? 1 (mod 4)
    if (N % 4 != 1) {
        cout << 0 << endl;
        return;
    }

    // ?? N_U (????) ? N_T (?????)
    long long N_U = (N - 1) / 4;
    long long N_T = (3LL * N + 1) / 4;

    // Q ???????
    long long Q = N_T - P;

    // K ? c_i = d_i - 1 ???
    // K = sum(d_i) - Q = (N-1 - 8P) - Q
    // K = N - 1 - 8P - (N_T - P) = N - 1 - 7P - N_T
    long long K_val = N - 1 - 7 * P - N_T;

    // P ??????P <= (N-5)/28
    // K >= 0 ?????
    if (K_val < 0) {
        cout << 0 << endl;
        return;
    }
    
    // Q ????Q >= 0
    if (Q < 0) {
        cout << 0 << endl;
        return;
    }

    // c_i <= 6 ????K <= 6Q
    if (K_val > 6 * Q) {
        cout << 0 << endl;
        return;
    }

    int K = (int)K_val;

    // --- 2. ????? ---
    long long ans = 1;

    // ????: C(N, N_T) * C(N_T, P)
    ans = (ans * nCr(N, N_T)) % MOD;
    ans = (ans * nCr((int)N_T, (int)P)) % MOD;

    // ???: (N-2)! / ((7!)^P * (3!)^N_U)
    // (N-2)!
    if (N >= 2) {
        ans = (ans * fact[N - 2]) % MOD;
    } else { // N=1, N=2 ????? N ? 1 (mod 4)
        // N >= 5 
    }

    // 7! ? 3! ???
    long long inv7Fact = invFact[7];
    long long inv3Fact = invFact[3];
    
    // (7!)^{-P}
    ans = (ans * power(inv7Fact, P)) % MOD;
    // (3!)^{-N_U}
    ans = (ans * power(inv3Fact, N_U)) % MOD;

    // --- 3. ????? T_Q ??? ---
    
    // G(x) = sum_{c=0}^6 x^c / c!
    vector<long long> G(7);
    for (int c = 0; c <= 6; c++) {
        G[c] = invFact[c];
    }

    // Q = N_T - P
    // T_Q = [x^K] G(x)^Q
    
    // ???????? H(x) = G(x)^Q mod x^{K+1}
    // ???????? n = K + 1
    int n = K + 1;
    vector<long long> H = poly_q_pow(G, Q, n);

    // ???? T_Q = H[K]
    long long T_Q = H[K];

    // --- 4. ???? ---
    ans = (ans * T_Q) % MOD;

    cout << ans << endl;
}

// Program entry point: switches iostreams to unsynchronized mode for faster
// input/output, then solves the single test case read from standard input.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    solve();
}
0