結果

問題 No.2237 Xor Sum Hoge
ユーザー qwewe
提出日時 2025-05-14 12:58:44
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 1,528 ms / 10,000 ms
コード長 9,542 bytes
コンパイル時間 1,973 ms
コンパイル使用メモリ 95,344 KB
実行使用メモリ 10,624 KB
最終ジャッジ日時 2025-05-14 13:00:34
合計ジャッジ時間 30,405 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 32
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <cmath> // Included for completeness, LOGP calculation done manually
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

// 64-bit signed integer used for every intermediate product of two residues.
using ll = long long;

// Prime modulus shared by all modular arithmetic in this file.
const int MOD = 998244353;

// Modular exponentiation by repeated squaring: returns base^exp modulo MOD.
//
// base - any 64-bit integer; it is reduced into [0, MOD) first, so a negative
//        base yields the canonical non-negative residue instead of a negative one.
// exp  - exponent; for exp <= 0 the loop never runs and 1 is returned.
//
// The result always lies in [0, MOD).
ll power(ll base, ll exp) {
    base %= MOD;
    if (base < 0) base += MOD; // C++ '%' keeps the sign of the dividend
    ll res = 1;
    for (; exp > 0; exp >>= 1) {
        if (exp & 1) res = res * base % MOD; // fold in the current binary digit
        base = base * base % MOD;            // square for the next digit
    }
    return res;
}

// Modular inverse via Fermat's little theorem: returns n^(MOD-2) modulo MOD.
//
// n is reduced into [0, MOD) before use, so negative arguments give the
// canonical non-negative inverse. A multiple of MOD (including 0) has no
// inverse; 0 is returned for it as a sentinel.
ll modInverse(ll n) {
    n %= MOD;
    if (n < 0) n += MOD; // C++ '%' keeps the sign of the dividend
    if (n == 0) return 0;
    return power(n, MOD - 2);
}


// Number Theoretic Transform (NTT) modulo MOD.
// All lookup tables are sized for transforms of length at most MAX_LEN.
namespace NTT {
    // Largest supported transform length (a power of two).
    // Max N=60000 -> product degree 2N=120000 -> requires length >= 120001,
    // and the smallest power of two above that is 2^17 = 131072.
    const int MAX_LEN = 1 << 17;
    ll G = 3;          // Primitive root modulo 998244353
    ll W[MAX_LEN];     // W[i] = omega^i, where omega = G^((MOD-1)/MAX_LEN)
    ll invW[MAX_LEN];  // invW[i] = omega^(-i)
    int rev[MAX_LEN];  // Bit-reversal permutation for the length stored in rev_len
    bool roots_precomputed = false; // Set once W / invW have been filled
    int rev_len = 0;   // Transform length `rev` was last built for (0 = never built)

    // True when n is a power of two in [1, MAX_LEN], i.e. a usable transform length.
    bool is_valid_length(int n) {
        return n >= 1 && n <= MAX_LEN && (n & (n - 1)) == 0;
    }

    // Fills W and invW with the powers of a primitive MAX_LEN-th root of unity.
    //
    // N_fft is validated (power of two, <= MAX_LEN) but the table is always built
    // for MAX_LEN: ntt() indexes it with stride MAX_LEN / len, so a full-size table
    // serves every shorter length as well. The work is done only on the first call.
    //
    // Throws std::invalid_argument when N_fft is not a valid transform length.
    void precompute_roots(int N_fft) {
        if (!is_valid_length(N_fft)) {
            throw std::invalid_argument(
                "NTT::precompute_roots: N_fft must be a power of two in [1, MAX_LEN]");
        }
        if (roots_precomputed) return;
        const ll WN = power(G, (MOD - 1) / MAX_LEN); // primitive MAX_LEN-th root of unity
        const ll invWN = modInverse(WN);
        W[0] = invW[0] = 1;
        for (int i = 1; i < MAX_LEN; ++i) {
            W[i] = (W[i - 1] * WN) % MOD;
            invW[i] = (invW[i - 1] * invWN) % MOD;
        }
        roots_precomputed = true;
    }

    // Builds the bit-reversal permutation for transforms of length P = 2^LOGP
    // and records P in rev_len.
    //
    // Throws std::invalid_argument when P is not a valid transform length or
    // LOGP is not its base-2 logarithm.
    void precompute_bitrev(int P, int LOGP) {
        if (!is_valid_length(P) || LOGP < 0 || LOGP > 30 || (1 << LOGP) != P) {
            throw std::invalid_argument(
                "NTT::precompute_bitrev: P must be a power of two in [1, MAX_LEN] with P == 2^LOGP");
        }
        rev[0] = 0;
        for (int i = 1; i < P; ++i) {
            // Reverse of i = reverse of (i >> 1) shifted down one place, with the
            // lowest bit of i moved to the highest position.
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (LOGP - 1));
        }
        rev_len = P;
    }

    // In-place NTT (invert == false) or inverse NTT (invert == true) of the first
    // P entries of `a`.
    //
    // a      - coefficient vector; zero-padded to P entries when shorter. Entries
    //          are expected to be residues in [0, MOD).
    // P      - transform length, a power of two in [1, MAX_LEN].
    // invert - selects the inverse transform, which also scales the result by 1/P.
    //
    // The root table and the bit-reversal table are built on demand when the
    // precompute_* functions were not called (or were called for another length).
    //
    // Throws std::invalid_argument when P is not a valid transform length.
    void ntt(std::vector<ll>& a, int P, bool invert) {
        if (!is_valid_length(P)) {
            throw std::invalid_argument(
                "NTT::ntt: P must be a power of two in [1, MAX_LEN]");
        }
        if (!roots_precomputed) precompute_roots(MAX_LEN);
        if (rev_len != P) {
            int log_p = 0;
            while ((1 << log_p) < P) ++log_p;
            precompute_bitrev(P, log_p);
        }

        if (static_cast<int>(a.size()) < P) a.resize(P, 0);

        // Bit-reversal permutation; the i < rev[i] test swaps each pair once.
        for (int i = 0; i < P; ++i) {
            if (i < rev[i]) {
                std::swap(a[i], a[rev[i]]);
            }
        }

        const ll* roots = invert ? invW : W;
        for (int len = 2; len <= P; len <<= 1) {
            const int half_len = len >> 1;
            // The table holds MAX_LEN-th roots, so the len-th roots sit at this stride.
            const int root_step = MAX_LEN / len;
            for (int i = 0; i < P; i += len) {
                for (int j = 0; j < half_len; ++j) {
                    const ll u = a[i + j];
                    const ll v = (a[i + j + half_len] * roots[j * root_step]) % MOD;
                    a[i + j] = (u + v) % MOD;
                    a[i + j + half_len] = (u - v + MOD) % MOD; // + MOD keeps it non-negative
                }
            }
        }

        if (invert) {
            const ll invP = modInverse(P);
            for (int i = 0; i < P; ++i) {
                a[i] = (a[i] * invP) % MOD;
            }
        }
    }
}

// Size of the factorial tables: the largest supported N (60000) plus slack.
const int MAXN_PARAM = 60000 + 5;
// fact[i] = i! mod MOD and invFact[i] = (i!)^(-1) mod MOD, filled by
// precompute_combinations(); entries beyond the filled range stay zero.
ll fact[MAXN_PARAM];
ll invFact[MAXN_PARAM];

// Fills fact[0..N_max] and invFact[0..N_max].
//
// N_max is clamped into [0, MAXN_PARAM - 1]: a value that is too large is cut
// down to the table size, and a negative value is treated as 0 so that the
// write to invFact[N_max] can never land outside the array.
void precompute_combinations(int N_max) {
    if (N_max < 0) N_max = 0;
    if (N_max >= MAXN_PARAM) N_max = MAXN_PARAM - 1;

    fact[0] = 1;
    for (int i = 1; i <= N_max; ++i) {
        fact[i] = (fact[i - 1] * i) % MOD;
    }

    // One modular inverse at the top, then walk down using 1/(i-1)! = (1/i!) * i.
    invFact[N_max] = modInverse(fact[N_max]);
    for (int i = N_max; i > 0; --i) {
        invFact[i - 1] = (invFact[i] * i) % MOD;
    }
}

// Binomial coefficient C(n, k) modulo MOD, looked up in the fact / invFact tables.
// Yields 0 when k lies outside [0, n] or when n does not fit the tables
// (n >= MAXN_PARAM). The tables are read as-is, so they must already cover n.
ll combinations(int n, int k) {
    const bool k_inside = (0 <= k) && (k <= n);
    if (!k_inside || n >= MAXN_PARAM) {
        return 0;
    }
    const ll n_fact_over_k_fact = (fact[n] * invFact[k]) % MOD;
    return (n_fact_over_k_fact * invFact[n - k]) % MOD;
}


int main() {
    // Use fast I/O operations
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(NULL);

    int N; // Problem parameter N: length of sequence A
    unsigned long long B_val, C_val; // Target sum B and XOR sum C
    std::cin >> N >> B_val >> C_val;

    // Precompute factorials and inverse factorials needed for combinations up to N
    precompute_combinations(N); 

    // Determine required FFT size P: smallest power of 2 such that P > 2N
    // The product polynomial degree can be up to N + N = 2N.
    int P = 1; 
    int LOGP = 0;
    while (P <= 2 * N) { 
        P <<= 1;
        LOGP++;
    }
    
    // Precompute NTT roots for the maximum possible FFT size (MAX_LEN)
    NTT::precompute_roots(NTT::MAX_LEN); 
    // Precompute bit reversal indices specifically for the calculated FFT size P
    NTT::precompute_bitrev(P, LOGP); 

    // Initialize DP state vector `dp_coeffs`. dp_coeffs[c] stores dp[k][c].
    // Initially, dp[0][0] = 1 (1 way to have carry 0 before processing bit 0).
    std::vector<ll> dp_coeffs(N + 1, 0);
    dp_coeffs[0] = 1; 

    // Iterate through bits k from 0 to 59 (since B, C < 2^60)
    for (int k = 0; k < 60; ++k) {
        // Extract k-th bits of B and C
        unsigned long long B_k = (B_val >> k) & 1;
        unsigned long long C_k = (C_val >> k) & 1;

        // Prepare the polynomial representing current DP states for NTT (size P)
        std::vector<ll> current_dp(P, 0); 
        // Determine required parity for carry c_k: c_k % 2 == B_k ^ C_k
        ll required_carry_parity = B_k ^ C_k; 
        // Filter DP states based on carry parity: copy only valid states
        for (int c = 0; c <= N; ++c) {
             if (dp_coeffs[c] != 0) { // Consider only reachable states (non-zero counts)
                  if ((c % 2) == required_carry_parity) { 
                       current_dp[c] = dp_coeffs[c];
                   }
            }
        }

        // Prepare the polynomial representing combinations for NTT (size P)
        // P_N(y) = sum_{x=0..N, x%2 == C_k} C(N, x) * y^x
        std::vector<ll> poly_comb(P, 0); 
        ll required_x_parity = C_k; // x_k must have parity C_k
        for (int x = 0; x <= N; ++x) {
            if ((x % 2) == required_x_parity) {
                poly_comb[x] = combinations(N, x);
            }
        }

        // Transform both polynomials to frequency domain using NTT
        NTT::ntt(current_dp, P, false);
        NTT::ntt(poly_comb, P, false);

        // Perform pointwise multiplication in frequency domain
        std::vector<ll> H_coeffs_ntt(P); // Stores NTT of the product polynomial H(y)
        for (int i = 0; i < P; ++i) {
            H_coeffs_ntt[i] = (current_dp[i] * poly_comb[i]) % MOD;
        }

        // Transform product back to coefficient domain using inverse NTT
        NTT::ntt(H_coeffs_ntt, P, true); 

        // Calculate coefficients for the next DP state dp[k+1] from H(y) coefficients
        std::vector<ll> new_dp_coeffs(N + 1, 0); // Stores dp[k+1][j] values
        for (int j = 0; j <= N; ++j) { // Index j represents the next carry c_{k+1}
             // Check bounds: indices 2j and 2j+1 must be less than P
             ll h_2j = (2 * j < P) ? H_coeffs_ntt[2 * j] : 0;
             ll h_2j_plus_1 = (2 * j + 1 < P) ? H_coeffs_ntt[2 * j + 1] : 0;
             // Update rule: dp[k+1][j] = h_{2j} + h_{2j+1}
             new_dp_coeffs[j] = (h_2j + h_2j_plus_1) % MOD;
             // Ensure result is non-negative modulo MOD (though addition should preserve non-negativity)
             // if (new_dp_coeffs[j] < 0) new_dp_coeffs[j] += MOD; 
        }
        // Update DP state vector for the next bit iteration
        dp_coeffs = new_dp_coeffs; 
    }

    // The final answer is dp[60][0]: number of ways resulting in carry 0 after processing all 60 bits.
    std::cout << dp_coeffs[0] << std::endl;

    return 0;
}
0