結果

問題 No.856 増える演算
ユーザー qwewe
提出日時 2025-05-14 13:08:10
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 11,100 bytes
コンパイル時間 986 ms
コンパイル使用メモリ 108,736 KB
実行使用メモリ 22,176 KB
最終ジャッジ日時 2025-05-14 13:09:59
合計ジャッジ時間 9,006 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 56 WA * 24
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
#include <complex>
#include <cmath> // For round, acos

// Define MOD constant
const long long MOD = 1000000007;
// Define MOD-1 constant, used for exponents in modular exponentiation via Fermat's Little Theorem
const long long MOD_MINUS_1 = 1000000006;

// Modular exponentiation: computes base^exp % MOD
// Handles base correctly modulo MOD. Handles exponent using Fermat's Little Theorem property.
// Assumes exp >= 0. Base can be anything.
long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    // Ensure base is non-negative after initial modulo
    if (base < 0) base += MOD; 
    
    // Handle 0^0 = 1. If base is 0 and exp is 0, result is 1. 
    // If base is 0 and exp > 0, result is 0.
    if (base == 0) return (exp == 0) ? 1 : 0;

    // The exponent used in calculation is exp % MOD_MINUS_1 based on Fermat's Little Theorem.
    // However, if exp is a multiple of MOD_MINUS_1, exp % MOD_MINUS_1 = 0.
    // In this case, the effective exponent should be MOD_MINUS_1 if exp > 0.
    // If exp = 0, effective exponent is 0. power(base, 0) correctly yields 1.
    // This function takes `exp` already processed. So we compute base^exp mod MOD directly.
    
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

// Modular inverse using Fermat's Little Theorem: computes n^(MOD-2) % MOD
// Assumes n is not divisible by MOD.
long long modInverse(long long n) {
    n %= MOD;
    if (n < 0) n += MOD;
    // Inverse of 0 doesn't exist. Problem constraints A_i >= 1 ensure V_min > 0.
    // We also proved V_min % MOD != 0.
    if (n == 0) {
       // This indicates an error state. Return 1 as a fallback might hide issues.
       // Let's trust the analysis that n != 0.
       // Consider throwing exception or assert if necessary for debugging.
       return 1; // Placeholder, potentially problematic
    }
    // Use Fermat's Little Theorem: n^(MOD-2) is the inverse of n mod MOD
    return power(n, MOD - 2);
}

// FFT implementation using complex numbers
const double PI = acos(-1.0);

// Performs FFT or Inverse FFT based on `invert` flag
void fft(std::vector<std::complex<double>>& a, bool invert) {
    int n = a.size();
    if (n == 1) return; // Base case for recursion

    // Bit-reversal permutation could be implemented here for iterative FFT.
    // Using simpler recursive version: split into even and odd indexed elements
    std::vector<std::complex<double>> a0(n / 2), a1(n / 2);
    for (int i = 0; 2 * i < n; i++) {
        a0[i] = a[2 * i];
        a1[i] = a[2 * i + 1];
    }
    // Recursively compute FFT for halves
    fft(a0, invert);
    fft(a1, invert);

    // Combine results using butterfly operations
    double ang = 2 * PI / n * (invert ? -1 : 1); // Angle for roots of unity depends on direction
    std::complex<double> w(1), wn(cos(ang), sin(ang)); // w is current root of unity, wn is the primitive nth root
    for (int i = 0; 2 * i < n; i++) {
        std::complex<double> t = w * a1[i]; // Multiply odd part result by root of unity
        a[i] = a0[i] + t; // Combine even and odd parts
        a[i + n / 2] = a0[i] - t; // Combine even and odd parts
        // For Inverse FFT, divide by 2 at each stage (equivalent to dividing by N at the end)
        if (invert) {
             a[i] /= 2;
             a[i + n / 2] /= 2;
        }
        w *= wn; // Move to the next root of unity
    }
}


// Polynomial multiplication using FFT. Returns coefficients of A(x) * B(x).
std::vector<long long> multiply(const std::vector<long long>& a, const std::vector<long long>& b) {
    // Convert coefficient vectors to complex vectors
    std::vector<std::complex<double>> fa(a.size()), fb(b.size());
    for(size_t i=0; i<a.size(); ++i) fa[i] = a[i];
    for(size_t i=0; i<b.size(); ++i) fb[i] = b[i];
    
    // Determine FFT size (power of 2 >= result degree + 1)
    int n = 1;
    // Result degree is (a.size()-1) + (b.size()-1). Result length is sum of degrees + 1 = a.size() + b.size() - 1.
    while (n < a.size() + b.size() - 1) n <<= 1; 
    // Resize vectors to FFT size, padding with zeros
    fa.resize(n);
    fb.resize(n);

    // Compute FFT of both polynomials
    fft(fa, false); 
    fft(fb, false); 
    // Pointwise multiply in frequency domain
    for (int i = 0; i < n; i++) fa[i] *= fb[i]; 
    // Compute Inverse FFT to get result polynomial coefficients
    fft(fa, true); 

    // The result polynomial has length a.size() + b.size() - 1
    std::vector<long long> result(a.size() + b.size() - 1); 
    for (size_t i = 0; i < result.size(); i++) {
         // Ensure index i is within bounds of fa (which has size n)
         if (i < fa.size()) {
             // Convert complex result back to long long, rounding to handle potential precision errors
             result[i] = static_cast<long long>(round(fa[i].real())); 
         } else {
             result[i] = 0; // Should not happen if n calculation is correct, but safe fallback
         }
    }
    return result;
}

int main() {
    std::ios_base::sync_with_stdio(false); // Faster I/O
    std::cin.tie(NULL);

    int N; // Number of integers
    std::cin >> N;
    std::vector<long long> A(N); // Input array A
    int V_max = 0; // Maximum value in A, needed for polynomial size
    for (int i = 0; i < N; ++i) {
        std::cin >> A[i];
        if (A[i] > V_max) {
            V_max = A[i];
        }
    }

    // Compute frequency counts of values in A
    std::vector<long long> counts(V_max + 1, 0);
    for (long long val : A) {
        // Check if value is within expected range [1, V_max]
        if (val >= 1 && val <= V_max) { 
           counts[val]++;
        }
    }

    // Part 1: Compute P_sum = product of (A_i + A_j) for i < j
    
    // Construct polynomial P(x) = sum counts[v] * x^v. The degree is V_max. Length is V_max+1.
    std::vector<long long> P_poly(V_max + 1, 0);
    for (int v = 1; v <= V_max; ++v) {
        P_poly[v] = counts[v];
    }

    // Compute P(x)^2 using FFT based polynomial multiplication
    // The result contains coefficients N''_S for x^S, where N''_S = |{(i, j) | A_i + A_j = S}|
    std::vector<long long> P_poly_sq_coeffs = multiply(P_poly, P_poly);
    
    long long P_sum = 1; // Initialize product P_sum
    // The result P_poly_sq_coeffs has indices up to 2 * V_max. Its size is V_max+1 + V_max+1 - 1 = 2*V_max + 1.
    int max_S = P_poly_sq_coeffs.size() - 1; // Maximum index S in the result vector

    // Iterate through possible sums S from 2 up to 2*V_max
    for (int S = 2; S <= max_S; ++S) {
         // Check bounds and if coefficient is zero
         if (S >= P_poly_sq_coeffs.size() || P_poly_sq_coeffs[S] == 0) continue;
         
        long long N_double_prime_S = P_poly_sq_coeffs[S]; // Coefficient of x^S in P(x)^2
        
        // Calculate N_S = |{(i, j) | i < j, A_i + A_j = S}|
        // Use formula 2 * N_S = N''_S - |{(i, i) | A_i + A_i = S}|
        long long term_cnt = N_double_prime_S;
        if (S % 2 == 0) { // Check diagonal contribution only if S is even
            int v = S / 2;
            // Check if v is a valid value within the range [1, V_max]
            if (v >= 1 && v <= V_max) {
                 term_cnt -= counts[v]; // Subtract count of elements equal to v = S/2
            }
        }
        
        // term_cnt = 2 * N_S. It's guaranteed to be even.
        long long N_S = term_cnt / 2;

        if (N_S == 0) continue; // If count N_S is 0, S^0 = 1, doesn't affect product

        // Compute S^(N_S) mod MOD. The exponent needs to be taken modulo MOD-1.
        long long exponent = N_S;
        long long E = exponent % MOD_MINUS_1;
        // E will be >= 0 since exponent N_S >= 0.
        
        long long S_pow_E_S = power(S, E); // Calculate S^E mod MOD
        P_sum = (P_sum * S_pow_E_S) % MOD; // Multiply into running product P_sum
    }
    // Ensure P_sum is positive
     if (P_sum < 0) P_sum += MOD;

    // Part 2: Compute P_pow = product of A_i^(S_{i+1}) for i=0..N-2
    // S_{i+1} = sum(A_j for j=i+1..N-1)
    
    // Compute suffix sums of A modulo MOD-1
    std::vector<long long> suffix_sum_mod(N + 1, 0);
    for (int i = N - 1; i >= 0; --i) {
        // Calculate sum carefully to handle potential intermediate overflow if N*V_max > 2^63
        // N*V_max <= 10^5 * 10^5 = 10^10, fits in long long.
        suffix_sum_mod[i] = (suffix_sum_mod[i+1] + A[i]);
         // Perform modulo operation
        if (suffix_sum_mod[i] >= MOD_MINUS_1 || suffix_sum_mod[i] < 0) {
             suffix_sum_mod[i] %= MOD_MINUS_1;
             if (suffix_sum_mod[i] < 0) suffix_sum_mod[i] += MOD_MINUS_1; // Ensure non-negative
        }
    }

    long long P_pow = 1; // Initialize product P_pow
    for (int i = 0; i < N - 1; ++i) { // Loop i from 0 to N-2 (covers pairs (i, j) where i fixed, j=i+1..N-1)
        long long base = A[i];
        // Base cannot be 0 since A_i >= 1.
        
        // Exponent is S_{i+1} mod MOD-1
        long long exponent_val_mod = suffix_sum_mod[i + 1]; 
        long long E = exponent_val_mod; // The effective exponent modulo MOD-1
        // The power function handles E=0 case correctly (returns 1)
        
        long long term_pow = power(base, E); // Calculate base^E mod MOD
        P_pow = (P_pow * term_pow) % MOD; // Multiply into running product P_pow
    }
     // Ensure P_pow is positive
    if (P_pow < 0) P_pow += MOD;


    // Part 3: Compute total product P = P_sum * P_pow mod MOD
    long long P_total = (P_sum * P_pow) % MOD;
    if (P_total < 0) P_total += MOD; // Ensure positive result

    // Part 4: Compute V_min = min_{i<j} (A_i + A_j) * A_i^(A_j)
    // Analysis showed V_min = (A_1 + A_2) * A_1^(A_2) after sorting A.
    std::vector<long long> A_sorted = A; // Create a copy to sort
    std::sort(A_sorted.begin(), A_sorted.end());
    long long A1 = A_sorted[0]; // Smallest element
    long long A2 = A_sorted[1]; // Second smallest element (N>=2 guaranteed)
    
    long long V_min_val;
    long long exponent_val = A2; // The exponent in V_{1,2} is A_2
    long long E_vmin = exponent_val % MOD_MINUS_1; // Exponent modulo MOD-1
    // E_vmin >= 0 since A2 >= 1.

    long long A1_pow_A2 = power(A1, E_vmin); // Calculate A1^E_vmin mod MOD
    
    V_min_val = (A1 + A2); // Calculate (A1+A2)
    V_min_val %= MOD; // Take modulo
    if (V_min_val < 0) V_min_val += MOD; // Ensure positive
    
    V_min_val = (V_min_val * A1_pow_A2) % MOD; // Final V_min value modulo MOD
    if (V_min_val < 0) V_min_val += MOD; // Ensure positive result

    // Part 5: Compute final result M = P / V_min mod MOD = P * V_min_inv mod MOD
    long long V_min_inv = modInverse(V_min_val); // Compute modular inverse of V_min_val

    long long M = (P_total * V_min_inv) % MOD; // Calculate M = P * V_min^{-1}
    if (M < 0) M += MOD; // Ensure final result is in [0, MOD-1]

    std::cout << M << std::endl; // Output the final result

    return 0;
}
0