結果

問題 No.1025 Modular Equation
ユーザー qwewe
提出日時 2025-05-14 13:13:23
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
TLE  
実行時間 -
コード長 9,431 bytes
コンパイル時間 1,621 ms
コンパイル使用メモリ 114,560 KB
実行使用メモリ 6,272 KB
最終ジャッジ日時 2025-05-14 13:14:27
合計ジャッジ時間 9,418 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 6 TLE * 1 -- * 25
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <iostream>
#include <vector>
#include <numeric>
#include <complex>
#include <vector>
#include <cmath>
#include <map>
#include <algorithm> // For std::fill

using namespace std;

// Modular exponentiation by repeated squaring: returns (base^exp) mod `mod`.
//
// Parameters:
//   base - any integer; negative values are first mapped into [0, mod).
//   exp  - exponent; values <= 0 are treated as 0 (the loop does not run).
//   mod  - modulus, expected to be >= 1. The intermediate products are formed
//          in `long long`, so `mod` must stay below roughly 3.03e9
//          (sqrt(LLONG_MAX)) to avoid signed overflow.
//
// Returns a value in [0, mod). In particular the result is 0 when mod == 1.
long long power(long long base, long long exp, long long mod) {
    // `1 % mod` rather than `1`, so that the empty product is still reduced
    // (everything is congruent to 0 modulo 1).
    long long res = 1 % mod;
    base %= mod;
    // C++ '%' keeps the sign of the dividend; shift negative residues up so
    // the result never comes out negative.
    if (base < 0) base += mod;
    while (exp > 0) {
        if (exp & 1) res = (res * base) % mod;
        base = (base * base) % mod;
        exp >>= 1;
    }
    return res;
}

// Multiplicative inverse of `n` modulo `mod`, obtained through Fermat's little
// theorem as n^(mod-2) mod `mod`.
// The identity only holds when `mod` is prime and `n` is not a multiple of it;
// neither condition is checked here.
// An input of exactly 0 has no inverse and yields the sentinel value 0.
long long modInverse(long long n, long long mod) {
    // For mod == 2 the exponent below is 0, and power() then returns 1,
    // which is the correct inverse of 1 modulo 2.
    return (n == 0) ? 0 : power(n, mod - 2, mod);
}

// Modulus (10^9 + 7) under which the final answer and all polynomial
// coefficients are reduced.
constexpr int MOD = 1'000'000'007;

// Complex sample type for the FFT; long double is chosen to get the most
// precision available from a built-in floating-point type.
using cd = complex<long double>;
// pi in long double precision, used to derive the FFT rotation angles.
const long double PI = std::acos(-1.0L);

// Iterative Fast Fourier Transform implementation
// Computes DFT (if invert=false) or inverse DFT (if invert=true) of vector 'a'
void fft_iterative(vector<cd>& a, bool invert) {
    int n = a.size();
    // Base case: FFT of size 1 is the element itself
    if (n <= 1) return; 

    // Bit-reversal permutation: reorders elements according to bit-reversed indices
    // This prepares the array for the iterative Cooley-Tukey algorithm
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1)
            j ^= bit;
        j ^= bit;

        if (i < j)
            swap(a[i], a[j]);
    }

    // Cooley-Tukey algorithm: iteratively combines DFTs of smaller lengths
    for (int len = 2; len <= n; len <<= 1) {
        // Calculate the principal len-th root of unity (omega_len)
        long double angle_base = 2 * PI / len * (invert ? -1.0L : 1.0L);
        cd wlen(cosl(angle_base), sinl(angle_base));
        // Process segments of length 'len'
        for (int i = 0; i < n; i += len) {
            cd w(1); // Current power of omega_len
            // Combine results from two halves of length len/2
            for (int j = 0; j < len / 2; j++) {
                cd u = a[i + j]; // Element from first half
                cd v = a[i + j + len / 2] * w; // Element from second half multiplied by root of unity
                a[i + j] = u + v; // Butterfly operation: update first half element
                a[i + j + len / 2] = u - v; // Butterfly operation: update second half element
                w *= wlen; // Move to the next power of omega_len
            }
        }
    }

    // If computing inverse DFT, scale results by 1/n
    if (invert) {
        for (cd & x : a) {
            x /= n;
        }
    }
}


// Polynomial multiplication using FFT. Computes (p1 * p2) mod (Z^p - 1).
// The result is the coefficients of the product polynomial modulo Z^p - 1.
// Coefficients are computed modulo MOD.
vector<long long> multiply(const vector<long long>& p1, const vector<long long>& p2, int p) {
    int N = 1;
    // FFT length N must be a power of 2.
    // For cyclic convolution of length p, the linear convolution might have degree up to 2(p-1).
    // The FFT length N must be at least p + p - 1 = 2p - 1 to avoid aliasing in linear convolution sense.
    while (N < 2 * p - 1) N <<= 1; 
    
    vector<cd> fp1(N, 0.0L), fp2(N, 0.0L);
    // Copy coefficients from input polynomials (assumed size p) to complex vectors of size N
    for (int i = 0; i < p; ++i) fp1[i] = p1[i];
    for (int i = 0; i < p; ++i) fp2[i] = p2[i];

    // Compute FFT of both polynomials
    fft_iterative(fp1, false); 
    fft_iterative(fp2, false); 
    
    // Pointwise product in frequency domain corresponds to convolution in time domain
    vector<cd> y_prod(N);
    for (int i = 0; i < N; ++i) y_prod[i] = fp1[i] * fp2[i]; 

    // Compute inverse FFT to get the result polynomial coefficients
    fft_iterative(y_prod, true); 

    vector<long long> res(p, 0);
    // Extract coefficients, round to nearest integer, handle cyclic convolution (mod p indices), and take modulo MOD
    for (int i = 0; i < N; ++i) {
        // Round the real part to the nearest integer. Imaginary part should be close to zero for integer convolution.
        long long val = static_cast<long long>(round(y_prod[i].real()));
        // Add the coefficient value to the appropriate index modulo p (for Z^p - 1 reduction)
        // Perform arithmetic modulo MOD. Add MOD before taking %MOD to handle potential negative values from rounding or intermediate calculations.
        res[i % p] = (res[i % p] + val % MOD + MOD) % MOD; 
    }
    
    return res;
}

// Polynomial exponentiation: computes base^exp mod (Z^p - 1) using binary exponentiation.
// Coefficients modulo MOD.
vector<long long> poly_pow(vector<long long> base, int exp, int p) {
    vector<long long> res(p);
    res[0] = 1; // Initialize result polynomial as 1 (identity for multiplication)

    // Ensure base polynomial vector has size p
    if (base.size() < p) base.resize(p, 0);

    // Binary exponentiation (repeated squaring) adapted for polynomials
    while (exp > 0) {
        if (exp % 2 == 1) res = multiply(res, base, p); // If exponent is odd, multiply result by current base power
        // Square the base polynomial for the next iteration, only if needed (exp > 1)
        if (exp > 1) base = multiply(base, base, p); 
        exp /= 2; // Halve the exponent
    }
    return res;
}


// Reads p, n, k, b followed by n coefficients a_i and prints, modulo MOD, the
// number of tuples (x_1..x_n) in [0, p)^n with  sum a_i * x_i^k == b (mod p).
//
// Method: for a non-zero coefficient A let P_A(Z) = sum_y C_A(y) Z^y, where
// C_A(y) is the number of x with A * x^k == y (mod p). The answer is the
// coefficient of Z^b in the product of all P_{a_i} modulo Z^p - 1, times p for
// every coefficient that is 0 modulo p.
//
// Returns exit code 1 if the input cannot be parsed or p < 1.
int main() {
    // Use fast I/O
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int p;       // Modulus of the equation (expected to be prime)
    int n;       // Number of terms in the sum
    long long k; // Exponent applied to every variable x_i
    int b;       // Target value of the sum modulo p
    if (!(cin >> p >> n >> k >> b) || p < 1) {
        return 1; // malformed input; p < 1 would also make every "% p" below invalid
    }

    // P_A depends only on the coset of A modulo the subgroup H of k-th powers
    // in the multiplicative group mod p: if A' = A * t^k, then x -> x*t maps
    // the solutions of A*x^k == y bijectively onto those of A'*x^k == y.
    // H has index g = gcd(k, p-1) and order (p-1)/g, and two non-zero residues
    // lie in the same coset exactly when their (p-1)/g-th powers agree. Grouping
    // coefficients by that power therefore needs at most g polynomial powers
    // instead of one per distinct coefficient.
    // For k <= 0, power() returns 1 for every x, so H is trivial and every
    // coefficient forms its own coset (exponent 1).
    long long coset_exp = 1;
    if (p > 1) {
        const long long order = p - 1;
        const long long index = (k > 0) ? std::gcd(k % order, order) : order;
        coset_exp = order / index;
    }

    // coset key -> (one representative coefficient, number of a_i in that coset)
    map<long long, pair<long long, int>> cosets;
    int zero_count = 0; // number of coefficients that are 0 modulo p

    for (int i = 0; i < n; ++i) {
        long long current_a;
        if (!(cin >> current_a)) {
            return 1;
        }
        // Reduce coefficient modulo p into [0, p)
        current_a %= p;
        if (current_a < 0) current_a += p;

        if (current_a == 0) {
            ++zero_count;
            continue;
        }
        const long long key = power(current_a, coset_exp, p);
        // try_emplace keeps the first representative seen for an existing coset.
        pair<long long, int>& group = cosets.try_emplace(key, current_a, 0).first->second;
        ++group.second;
    }

    // Reduce target value b modulo p into [0, p)
    b %= p;
    if (b < 0) b += p;

    // Edge case: every coefficient is 0 modulo p, so the equation reads 0 == b.
    if (cosets.empty()) {
        if (b == 0) {
            // Always true: each x_i is free, giving p^n solutions.
            cout << power(p, n, MOD) << '\n';
        } else {
            // Never true.
            cout << 0 << '\n';
        }
        return 0;
    }

    // x_k_counts[y] = number of x in [0, p) with x^k == y (mod p)
    vector<long long> x_k_counts(p, 0);
    for (int x = 0; x < p; ++x) {
        x_k_counts[power(x, k, p)]++;
    }

    vector<long long> final_poly; // running product; assigned by the first factor
    bool have_product = false;
    vector<long long> PA(p);      // reusable buffer for P_A(Z)

    for (const auto& entry : cosets) {
        const long long representative = entry.second.first;
        const int count = entry.second.second;

        // A * x^k == y  <=>  x^k == y * A^{-1}, so C_A(y) = x_k_counts[y * A^{-1} mod p].
        const long long A_inv = modInverse(representative, p);
        for (int y = 0; y < p; ++y) {
            PA[y] = x_k_counts[(1LL * y * A_inv) % p];
        }

        // (P_A)^count covers every coefficient of this coset at once.
        vector<long long> factor = poly_pow(PA, count, p);
        if (have_product) {
            final_poly = multiply(final_poly, factor, p);
        } else {
            // The first factor is the product so far; no need to multiply by 1.
            final_poly.swap(factor);
            have_product = true;
        }
    }

    // Each zero coefficient leaves its x_i unconstrained: a factor of p apiece.
    const long long zero_term_factor = power(p, zero_count, MOD);
    // Coefficient of Z^b, scaled by the contribution of the zero coefficients.
    const long long result = (final_poly[b] * zero_term_factor) % MOD;

    cout << result << '\n';

    return 0;
}
0