結果

問題 No.1762 🐙🐄🌲
ユーザー qwewe
提出日時 2025-05-14 13:06:17
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 1,367 ms / 4,000 ms
コード長 9,720 bytes
コンパイル時間 1,224 ms
コンパイル使用メモリ 82,516 KB
実行使用メモリ 28,392 KB
最終ジャッジ日時 2025-05-14 13:07:58
合計ジャッジ時間 32,021 ms
ジャッジサーバーID
(参考情報)
judge2 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 47
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm> // for std::swap, std::min

// --- Modular Arithmetic ---
// Prime modulus 998244353 = 119 * 2^23 + 1; every residue in this program is taken modulo it.
constexpr int MOD = 998244353;

// Computes base^exp mod MOD by square-and-multiply in O(log exp) steps.
// The base is first reduced into [0, MOD), so negative bases also give a
// canonical residue in [0, MOD). When exp <= 0 the loop does not run and the
// result is 1 (negative exponents are therefore treated like 0).
long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    if (base < 0) base += MOD;  // C++ '%' keeps the sign of the dividend
    while (exp > 0) {
        if (exp & 1) res = res * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return res;
}

// Modular inverse of n modulo the prime MOD via Fermat's little theorem:
// n^(MOD-2) * n == 1 (mod MOD) whenever n is not a multiple of MOD.
// For multiples of MOD the returned value is 0, which is NOT an inverse; the
// call sites visible in this file pass values coprime to MOD.
long long modInverse(long long n) {
    const long long fermat_exponent = MOD - 2;
    return power(n, fermat_exponent);
}


// --- Number Theoretic Transform (NTT) ---
// The transform length is fixed at 2^18 = 262144. For N up to 5e5 the largest K'
// is about 1.25e5, so the product of two polynomials with K'+1 coefficients each
// (at most 2*K'+1 coefficients) fits in one transform without wrap-around.
constexpr int NTT_LOG = 18;
constexpr int NTT_SIZE = 1 << NTT_LOG;
constexpr int G = 3;  // primitive root modulo 998244353

// Lookup tables populated by precompute_ntt():
long long W[NTT_SIZE];      // W[i]     = w^i for the NTT_SIZE-th root of unity w
long long W_inv[NTT_SIZE];  // W_inv[i] = w^(-i)
int rev[NTT_SIZE];          // rev[i]   = i with its NTT_LOG low bits reversed

// Fills the NTT lookup tables:
//   W[i]     = root^i      with root = G^((MOD-1)/NTT_SIZE),
//   W_inv[i] = root^(-i),
//   rev[i]   = i with its NTT_LOG low bits reversed.
// Must be called once before ntt() / multiply() are used.
void precompute_ntt() {
    const long long root = power(G, (MOD - 1) / NTT_SIZE);
    const long long root_inv = modInverse(root);

    long long fwd = 1;
    long long bwd = 1;
    for (int i = 0; i < NTT_SIZE; ++i) {
        W[i] = fwd;
        W_inv[i] = bwd;
        fwd = fwd * root % MOD;
        bwd = bwd * root_inv % MOD;
    }

    // Reversing i equals reversing i/2, shifting right by one, and moving
    // i's lowest bit into the highest of the NTT_LOG positions.
    rev[0] = 0;
    for (int i = 1; i < NTT_SIZE; ++i) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (NTT_LOG - 1));
    }
}


// In-place NTT of fixed length NTT_SIZE over Z/MOD.
//   invert == false: forward transform using the W table.
//   invert == true : inverse transform using W_inv, then each of the first
//                    NTT_SIZE entries is multiplied by NTT_SIZE^{-1}.
// A vector shorter than NTT_SIZE is zero-padded to NTT_SIZE first; entries at
// index NTT_SIZE and beyond (if the vector is longer) are left untouched.
// Assumes entries are already reduced into [0, MOD) so products fit in long long,
// and that precompute_ntt() has been called.
void ntt(std::vector<long long>& a, bool invert) {
    using size_type = std::vector<long long>::size_type;
    const int len = NTT_SIZE;
    if (a.size() < static_cast<size_type>(len)) {
        a.resize(len, 0);
    }

    // Bit-reversal reordering; the i < j test swaps each pair exactly once.
    for (int i = 0; i < len; ++i) {
        const int j = rev[i];
        if (i < j) std::swap(a[i], a[j]);
    }

    const long long* const table = invert ? W_inv : W;

    // Radix-2 butterflies: merge adjacent blocks of size `half` into blocks of size 2*half.
    for (int half = 1; half < len; half <<= 1) {
        const int stride = NTT_SIZE / (half << 1);  // root-table step for this level
        for (int start = 0; start < len; start += half << 1) {
            int t = 0;
            for (int k = start; k < start + half; ++k, t += stride) {
                const long long x = a[k];
                const long long y = a[k + half] * table[t] % MOD;
                a[k] = (x + y) % MOD;
                a[k + half] = (x - y + MOD) % MOD;
            }
        }
    }

    if (invert) {
        const long long scale = modInverse(len);
        for (int i = 0; i < len; ++i) {
            a[i] = a[i] * scale % MOD;
        }
    }
}

// Multiplies polynomials `a` and `b` (coefficients in [0, MOD), lowest degree first)
// and returns the first `result_len` coefficients of the product, i.e.
// (a * b) mod x^result_len. The result always has exactly result_len entries
// (zero-filled past the product's degree); result_len <= 0 yields an empty vector.
//
// The transform has the fixed length NTT_SIZE, so the convolution is cyclic.
// Both inputs are first truncated to result_len coefficients. This is exact,
// because terms of degree >= result_len cannot reach the kept prefix. After that
// truncation the result is correct whenever len(a') + len(b') - 1 <= NTT_SIZE,
// which always holds when 2 * result_len - 1 <= NTT_SIZE.
std::vector<long long> multiply(const std::vector<long long>& a, const std::vector<long long>& b, int result_len) {
    if (result_len <= 0) return {};
    const auto keep = static_cast<std::vector<long long>::size_type>(result_len);

    // Copy, drop coefficients that cannot affect the kept prefix, zero-pad to NTT_SIZE.
    auto prepare = [keep](const std::vector<long long>& poly) {
        std::vector<long long> out(poly);
        if (out.size() > keep) out.resize(keep);
        out.resize(NTT_SIZE, 0);
        return out;
    };
    std::vector<long long> fa = prepare(a);
    std::vector<long long> fb = prepare(b);

    ntt(fa, false);
    ntt(fb, false);

    // Pointwise product in the transformed domain, stored back into fa.
    for (int i = 0; i < NTT_SIZE; ++i) {
        fa[i] = fa[i] * fb[i] % MOD;
    }

    ntt(fa, true);  // back to coefficients

    fa.resize(keep);
    return fa;
}

// Returns base(x)^exp mod x^result_len as a vector of exactly result_len
// coefficients. It uses square-and-multiply over polynomials, with each product
// computed by multiply().
// exp <= 0 yields the constant polynomial 1; result_len == 0 yields an empty vector.
std::vector<long long> poly_pow(std::vector<long long> base, long long exp, int result_len) {
    std::vector<long long> acc(result_len);
    if (result_len <= 0) return {};
    acc[0] = 1;  // start from the polynomial 1

    // Coefficients of degree >= result_len never influence the kept prefix.
    if (base.size() > static_cast<std::vector<long long>::size_type>(result_len)) {
        base.resize(result_len);
    }

    for (long long e = exp; e > 0; e >>= 1) {
        if (e & 1) {
            acc = multiply(acc, base, result_len);
        }
        if (e > 1) {  // the square after the last bit would be unused
            base = multiply(base, base, result_len);
        }
    }

    acc.resize(result_len);
    return acc;
}

// --- Factorials ---
// Tables cover indices 0..500000, i.e. N up to 5e5; populated by precompute_factorials_optimized().
constexpr int MAX_N_Factorial = 500001;
long long fact[MAX_N_Factorial];     // fact[i]    = i! mod MOD
long long invFact[MAX_N_Factorial];  // invFact[i] = (i!)^{-1} mod MOD

// Fills fact[i] = i! mod MOD and invFact[i] = (i!)^{-1} mod MOD for 0 <= i < MAX_N_Factorial.
// Only one modular exponentiation is needed. The largest factorial is inverted once,
// and the rest follow walking downward via (i-1)!^{-1} = i!^{-1} * i.
void precompute_factorials_optimized() {
    const int top = MAX_N_Factorial - 1;

    fact[0] = 1;
    for (int i = 1; i <= top; ++i) {
        fact[i] = fact[i - 1] * i % MOD;
    }

    invFact[top] = modInverse(fact[top]);
    for (int i = top; i > 0; --i) {
        invFact[i - 1] = invFact[i] * i % MOD;
    }
}

// --- Main Logic ---
int main() {
    std::ios_base::sync_with_stdio(false); // Faster I/O
    std::cin.tie(NULL);

    int N;
    long long P; // P can be up to N, use long long
    std::cin >> N >> P;

    // Condition check: N must be at least 2 and N-1 must be divisible by 4
    // If N=1 mod 4 doesn't hold, or N < 2, no such tree exists.
    // The smallest N satisfying N>=2 and N=1 mod 4 is N=5.
    if ((N - 1) % 4 != 0 || N < 2) { 
        std::cout << 0 << std::endl;
        return 0;
    }

    // Calculate number of Octopus (N_O) and Cow (N_C) vertices
    long long N_O = (3LL * N + 1) / 4;
    long long N_C = (N - 1) / 4;

    // K is related to the sum of (degree-1) for Octopus vertices
    long long K = (N - 5) / 4; 
    // Check if K is valid. For N=5, K=0. Minimal N is 5, so K >= 0.

    // K' is the target sum of (degree-1) for non-Perfect Octopus vertices
    long long K_prime = K - 7LL * P;

    // If K' is negative, it's impossible to satisfy degree sum constraints.
    if (K_prime < 0) {
        std::cout << 0 << std::endl;
        return 0;
    }
    // P must be non-negative and cannot exceed the number of Octopus vertices.
    if (P < 0 || P > N_O) { 
        std::cout << 0 << std::endl;
        return 0;
    }

    // Precompute necessary values
    precompute_factorials_optimized();
    precompute_ntt(); 

    // Define the base polynomial f(x) = sum_{k=0..6} x^k / k!
    std::vector<long long> f(7); 
    for (int k = 0; k <= 6; ++k) {
         f[k] = invFact[k];
    }

    // Number of non-Perfect Octopus vertices
    long long N_O_minus_P = N_O - P;
    
    std::vector<long long> g; // Result of polynomial exponentiation
    int required_poly_len = K_prime + 1; // Need coefficient of x^K', so need poly up to degree K'

    // Check if required polynomial length exceeds NTT capability (sanity check)
    if (required_poly_len > NTT_SIZE) {
         std::cerr << "Error: Required polynomial length " << required_poly_len << " exceeds NTT size " << NTT_SIZE << "." << std::endl;
         return 1; // Should not happen based on constraints
    }
    
    // Compute g(x) = f(x)^(N_O - P)
    if (N_O_minus_P == 0) { // Special case: exponent is 0
         g.resize(required_poly_len, 0);
         if (K_prime == 0) g[0] = 1; // f(x)^0 = 1. Coeff of x^0 is 1. All others 0.
    } else {
         g = poly_pow(f, N_O_minus_P, required_poly_len);
    }

    // Extract the coefficient C_{K'} = [x^{K'}] g(x)
    long long C_K_prime = 0;
    if (K_prime < g.size()) { // Check index boundary
        C_K_prime = g[K_prime];
    } 
    // If K_prime >= g.size(), means K_prime >= required_poly_len. This coefficient is implicitly 0.
    
    // Precompute modular inverses of constants 6 and 5040 (7!)
    long long inv6 = modInverse(6);
    long long inv5040 = modInverse(5040); 

    // Calculate the final answer using the derived formula:
    // TotalCount = (N! * (N_O-1)!) / (N_C * 6^N_C * P! * (N_O-P)! * 5040^P) * C_{K'}
    long long ans = fact[N];
    ans = (ans * fact[N_O - 1]) % MOD;
    // N_C = (N-1)/4. Since N>=5, N_C >= 1. So N_C is non-zero.
    ans = (ans * modInverse(N_C)) % MOD; 
    ans = (ans * power(inv6, N_C)) % MOD; // (1/6)^N_C
    ans = (ans * invFact[P]) % MOD; // 1/P!
    // P <= N_O ensures N_O - P >= 0. Index is valid.
    ans = (ans * invFact[N_O - P]) % MOD; // 1/((N_O-P)!)
    ans = (ans * power(inv5040, P)) % MOD; // (1/5040)^P
    ans = (ans * C_K_prime) % MOD; // Multiply by the coefficient computed

    std::cout << ans << std::endl;

    return 0;
}
0