結果
| 問題 |
No.856 増える演算
|
| コンテスト | |
| ユーザー |
qwewe
|
| 提出日時 | 2025-05-14 13:08:10 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 11,100 bytes |
| コンパイル時間 | 986 ms |
| コンパイル使用メモリ | 108,736 KB |
| 実行使用メモリ | 22,176 KB |
| 最終ジャッジ日時 | 2025-05-14 13:09:59 |
| 合計ジャッジ時間 | 9,006 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 56 WA * 24 |
ソースコード
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
#include <complex>
#include <cmath> // For round, acos
// Define MOD constant
const long long MOD = 1000000007;
// Define MOD-1 constant, used for exponents in modular exponentiation via Fermat's Little Theorem
const long long MOD_MINUS_1 = 1000000006;
// Modular exponentiation: computes base^exp % MOD
// Handles base correctly modulo MOD. Handles exponent using Fermat's Little Theorem property.
// Assumes exp >= 0. Base can be anything.
long long power(long long base, long long exp) {
long long res = 1;
base %= MOD;
// Ensure base is non-negative after initial modulo
if (base < 0) base += MOD;
// Handle 0^0 = 1. If base is 0 and exp is 0, result is 1.
// If base is 0 and exp > 0, result is 0.
if (base == 0) return (exp == 0) ? 1 : 0;
// The exponent used in calculation is exp % MOD_MINUS_1 based on Fermat's Little Theorem.
// However, if exp is a multiple of MOD_MINUS_1, exp % MOD_MINUS_1 = 0.
// In this case, the effective exponent should be MOD_MINUS_1 if exp > 0.
// If exp = 0, effective exponent is 0. power(base, 0) correctly yields 1.
// This function takes `exp` already processed. So we compute base^exp mod MOD directly.
while (exp > 0) {
if (exp % 2 == 1) res = (res * base) % MOD;
base = (base * base) % MOD;
exp /= 2;
}
return res;
}
// Modular inverse using Fermat's Little Theorem: computes n^(MOD-2) % MOD
// Assumes n is not divisible by MOD.
long long modInverse(long long n) {
n %= MOD;
if (n < 0) n += MOD;
// Inverse of 0 doesn't exist. Problem constraints A_i >= 1 ensure V_min > 0.
// We also proved V_min % MOD != 0.
if (n == 0) {
// This indicates an error state. Return 1 as a fallback might hide issues.
// Let's trust the analysis that n != 0.
// Consider throwing exception or assert if necessary for debugging.
return 1; // Placeholder, potentially problematic
}
// Use Fermat's Little Theorem: n^(MOD-2) is the inverse of n mod MOD
return power(n, MOD - 2);
}
// FFT implementation using complex numbers
const double PI = acos(-1.0);
// Performs FFT or Inverse FFT based on `invert` flag
void fft(std::vector<std::complex<double>>& a, bool invert) {
int n = a.size();
if (n == 1) return; // Base case for recursion
// Bit-reversal permutation could be implemented here for iterative FFT.
// Using simpler recursive version: split into even and odd indexed elements
std::vector<std::complex<double>> a0(n / 2), a1(n / 2);
for (int i = 0; 2 * i < n; i++) {
a0[i] = a[2 * i];
a1[i] = a[2 * i + 1];
}
// Recursively compute FFT for halves
fft(a0, invert);
fft(a1, invert);
// Combine results using butterfly operations
double ang = 2 * PI / n * (invert ? -1 : 1); // Angle for roots of unity depends on direction
std::complex<double> w(1), wn(cos(ang), sin(ang)); // w is current root of unity, wn is the primitive nth root
for (int i = 0; 2 * i < n; i++) {
std::complex<double> t = w * a1[i]; // Multiply odd part result by root of unity
a[i] = a0[i] + t; // Combine even and odd parts
a[i + n / 2] = a0[i] - t; // Combine even and odd parts
// For Inverse FFT, divide by 2 at each stage (equivalent to dividing by N at the end)
if (invert) {
a[i] /= 2;
a[i + n / 2] /= 2;
}
w *= wn; // Move to the next root of unity
}
}
// Polynomial multiplication using FFT. Returns coefficients of A(x) * B(x).
std::vector<long long> multiply(const std::vector<long long>& a, const std::vector<long long>& b) {
// Convert coefficient vectors to complex vectors
std::vector<std::complex<double>> fa(a.size()), fb(b.size());
for(size_t i=0; i<a.size(); ++i) fa[i] = a[i];
for(size_t i=0; i<b.size(); ++i) fb[i] = b[i];
// Determine FFT size (power of 2 >= result degree + 1)
int n = 1;
// Result degree is (a.size()-1) + (b.size()-1). Result length is sum of degrees + 1 = a.size() + b.size() - 1.
while (n < a.size() + b.size() - 1) n <<= 1;
// Resize vectors to FFT size, padding with zeros
fa.resize(n);
fb.resize(n);
// Compute FFT of both polynomials
fft(fa, false);
fft(fb, false);
// Pointwise multiply in frequency domain
for (int i = 0; i < n; i++) fa[i] *= fb[i];
// Compute Inverse FFT to get result polynomial coefficients
fft(fa, true);
// The result polynomial has length a.size() + b.size() - 1
std::vector<long long> result(a.size() + b.size() - 1);
for (size_t i = 0; i < result.size(); i++) {
// Ensure index i is within bounds of fa (which has size n)
if (i < fa.size()) {
// Convert complex result back to long long, rounding to handle potential precision errors
result[i] = static_cast<long long>(round(fa[i].real()));
} else {
result[i] = 0; // Should not happen if n calculation is correct, but safe fallback
}
}
return result;
}
int main() {
std::ios_base::sync_with_stdio(false); // Faster I/O
std::cin.tie(NULL);
int N; // Number of integers
std::cin >> N;
std::vector<long long> A(N); // Input array A
int V_max = 0; // Maximum value in A, needed for polynomial size
for (int i = 0; i < N; ++i) {
std::cin >> A[i];
if (A[i] > V_max) {
V_max = A[i];
}
}
// Compute frequency counts of values in A
std::vector<long long> counts(V_max + 1, 0);
for (long long val : A) {
// Check if value is within expected range [1, V_max]
if (val >= 1 && val <= V_max) {
counts[val]++;
}
}
// Part 1: Compute P_sum = product of (A_i + A_j) for i < j
// Construct polynomial P(x) = sum counts[v] * x^v. The degree is V_max. Length is V_max+1.
std::vector<long long> P_poly(V_max + 1, 0);
for (int v = 1; v <= V_max; ++v) {
P_poly[v] = counts[v];
}
// Compute P(x)^2 using FFT based polynomial multiplication
// The result contains coefficients N''_S for x^S, where N''_S = |{(i, j) | A_i + A_j = S}|
std::vector<long long> P_poly_sq_coeffs = multiply(P_poly, P_poly);
long long P_sum = 1; // Initialize product P_sum
// The result P_poly_sq_coeffs has indices up to 2 * V_max. Its size is V_max+1 + V_max+1 - 1 = 2*V_max + 1.
int max_S = P_poly_sq_coeffs.size() - 1; // Maximum index S in the result vector
// Iterate through possible sums S from 2 up to 2*V_max
for (int S = 2; S <= max_S; ++S) {
// Check bounds and if coefficient is zero
if (S >= P_poly_sq_coeffs.size() || P_poly_sq_coeffs[S] == 0) continue;
long long N_double_prime_S = P_poly_sq_coeffs[S]; // Coefficient of x^S in P(x)^2
// Calculate N_S = |{(i, j) | i < j, A_i + A_j = S}|
// Use formula 2 * N_S = N''_S - |{(i, i) | A_i + A_i = S}|
long long term_cnt = N_double_prime_S;
if (S % 2 == 0) { // Check diagonal contribution only if S is even
int v = S / 2;
// Check if v is a valid value within the range [1, V_max]
if (v >= 1 && v <= V_max) {
term_cnt -= counts[v]; // Subtract count of elements equal to v = S/2
}
}
// term_cnt = 2 * N_S. It's guaranteed to be even.
long long N_S = term_cnt / 2;
if (N_S == 0) continue; // If count N_S is 0, S^0 = 1, doesn't affect product
// Compute S^(N_S) mod MOD. The exponent needs to be taken modulo MOD-1.
long long exponent = N_S;
long long E = exponent % MOD_MINUS_1;
// E will be >= 0 since exponent N_S >= 0.
long long S_pow_E_S = power(S, E); // Calculate S^E mod MOD
P_sum = (P_sum * S_pow_E_S) % MOD; // Multiply into running product P_sum
}
// Ensure P_sum is positive
if (P_sum < 0) P_sum += MOD;
// Part 2: Compute P_pow = product of A_i^(S_{i+1}) for i=0..N-2
// S_{i+1} = sum(A_j for j=i+1..N-1)
// Compute suffix sums of A modulo MOD-1
std::vector<long long> suffix_sum_mod(N + 1, 0);
for (int i = N - 1; i >= 0; --i) {
// Calculate sum carefully to handle potential intermediate overflow if N*V_max > 2^63
// N*V_max <= 10^5 * 10^5 = 10^10, fits in long long.
suffix_sum_mod[i] = (suffix_sum_mod[i+1] + A[i]);
// Perform modulo operation
if (suffix_sum_mod[i] >= MOD_MINUS_1 || suffix_sum_mod[i] < 0) {
suffix_sum_mod[i] %= MOD_MINUS_1;
if (suffix_sum_mod[i] < 0) suffix_sum_mod[i] += MOD_MINUS_1; // Ensure non-negative
}
}
long long P_pow = 1; // Initialize product P_pow
for (int i = 0; i < N - 1; ++i) { // Loop i from 0 to N-2 (covers pairs (i, j) where i fixed, j=i+1..N-1)
long long base = A[i];
// Base cannot be 0 since A_i >= 1.
// Exponent is S_{i+1} mod MOD-1
long long exponent_val_mod = suffix_sum_mod[i + 1];
long long E = exponent_val_mod; // The effective exponent modulo MOD-1
// The power function handles E=0 case correctly (returns 1)
long long term_pow = power(base, E); // Calculate base^E mod MOD
P_pow = (P_pow * term_pow) % MOD; // Multiply into running product P_pow
}
// Ensure P_pow is positive
if (P_pow < 0) P_pow += MOD;
// Part 3: Compute total product P = P_sum * P_pow mod MOD
long long P_total = (P_sum * P_pow) % MOD;
if (P_total < 0) P_total += MOD; // Ensure positive result
// Part 4: Compute V_min = min_{i<j} (A_i + A_j) * A_i^(A_j)
// Analysis showed V_min = (A_1 + A_2) * A_1^(A_2) after sorting A.
std::vector<long long> A_sorted = A; // Create a copy to sort
std::sort(A_sorted.begin(), A_sorted.end());
long long A1 = A_sorted[0]; // Smallest element
long long A2 = A_sorted[1]; // Second smallest element (N>=2 guaranteed)
long long V_min_val;
long long exponent_val = A2; // The exponent in V_{1,2} is A_2
long long E_vmin = exponent_val % MOD_MINUS_1; // Exponent modulo MOD-1
// E_vmin >= 0 since A2 >= 1.
long long A1_pow_A2 = power(A1, E_vmin); // Calculate A1^E_vmin mod MOD
V_min_val = (A1 + A2); // Calculate (A1+A2)
V_min_val %= MOD; // Take modulo
if (V_min_val < 0) V_min_val += MOD; // Ensure positive
V_min_val = (V_min_val * A1_pow_A2) % MOD; // Final V_min value modulo MOD
if (V_min_val < 0) V_min_val += MOD; // Ensure positive result
// Part 5: Compute final result M = P / V_min mod MOD = P * V_min_inv mod MOD
long long V_min_inv = modInverse(V_min_val); // Compute modular inverse of V_min_val
long long M = (P_total * V_min_inv) % MOD; // Calculate M = P * V_min^{-1}
if (M < 0) M += MOD; // Ensure final result is in [0, MOD-1]
std::cout << M << std::endl; // Output the final result
return 0;
}
qwewe