結果

問題 No.1391 ±1 Abs Sum
ユーザー qwewe
提出日時 2025-05-14 13:15:41
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 456 ms / 2,000 ms
コード長 8,230 bytes
コンパイル時間 931 ms
コンパイル使用メモリ 98,376 KB
実行使用メモリ 7,844 KB
最終ジャッジ日時 2025-05-14 13:17:21
合計ジャッジ時間 9,777 ms
ジャッジサーバーID
(参考情報)
judge4 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 34
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <iostream>
#include <vector>
#include <cmath>
#include <numeric>
#include <algorithm>
#include <iterator> // Required for std::distance
#include <limits> // Required for std::numeric_limits

using namespace std;

// 64-bit signed integer alias used for every value, prefix sum and distance below.
using ll = long long;

// Prefix sums of the input array, filled in by main():
// P[k] == A[0] + A[1] + ... + A[k-1], so after filling P has N + 1 entries and P[0] == 0.
vector<ll> P;

/**
 * @brief Sum of A[i] over the closed 0-based index range [L, R], read from the global prefix-sum array P.
 *
 * Relies on P[k] == A[0] + ... + A[k-1], so the range sum is P[R + 1] - P[L].
 * No bounds checking is done: whenever L <= R the caller must ensure 0 <= L and R + 1 < P.size().
 *
 * @param L First index of the range (inclusive).
 * @param R Last index of the range (inclusive).
 * @return The range sum, or 0 when L > R (empty range).
 */
ll sum_A(int L, int R) {
    const bool empty_range = R < L;
    return empty_range ? 0 : P[R + 1] - P[L];
}

/**
 * @brief Computes sum_{i=L}^{R} |A[j] - A[i]| for a non-decreasing array A in O(1) via prefix sums.
 *
 * The requested range is first clamped to the valid indices [0, A.size() - 1]. The terms are split at
 * the pivot index j: elements at or left of j are <= A[j], and elements right of j are >= A[j].
 * Each side then reduces to a count times A[j] plus or minus a range sum from sum_A.
 *
 * @param j Index of the fixed (pivot) element A[j].
 * @param L First index of the range for i (inclusive); clamped up to 0.
 * @param R Last index of the range for i (inclusive); clamped down to A.size() - 1.
 * @param A The input array, assumed sorted in non-decreasing order and matching the global P.
 * @return The sum of absolute differences. Returns 0 when A is empty, when L > R (before or after
 *         clamping), or when j is outside [0, A.size() - 1].
 */
ll sum_abs_diff(int j, int L, int R, const vector<ll>& A) {
    const int n = static_cast<int>(A.size());
    // Nothing to sum over: empty array or an inverted range.
    if (n == 0 || R < L) return 0;

    // Restrict the requested range to indices that exist.
    const int lo = max(L, 0);
    const int hi = min(R, n - 1);
    // Empty after clamping, or the pivot index itself is out of bounds.
    if (hi < lo || j < 0 || n <= j) return 0;

    const ll pivot = A[j];

    // Indices lo..min(j, hi): A is non-decreasing, so |pivot - A[i]| = pivot - A[i].
    const int below_last = min(j, hi);
    const ll below = (below_last < lo)
        ? 0
        : static_cast<ll>(below_last - lo + 1) * pivot - sum_A(lo, below_last);

    // Indices max(j + 1, lo)..hi: here |pivot - A[i]| = A[i] - pivot.
    const int above_first = max(j + 1, lo);
    const ll above = (hi < above_first)
        ? 0
        : sum_A(above_first, hi) - static_cast<ll>(hi - above_first + 1) * pivot;

    return below + above;
}


int main() {
    // Use faster I/O operations.
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int N;
    ll K_ll; // Use long long for K initially as problem constraints allow large N.
    cin >> N >> K_ll;
    int K = (int)K_ll; // Convert K to int. K <= N <= 2e5, so K fits in int.

    vector<ll> A(N);
    for (int i = 0; i < N; ++i) {
        cin >> A[i];
    }

    // Handle the edge case where N=0.
    if (N == 0) {
        cout << 0 << endl;
        return 0;
    }
    
    // Precompute prefix sums of the array A. P has size N+1.
    P.resize(N + 1, 0);
    for (int i = 0; i < N; ++i) {
        P[i + 1] = P[i] + A[i];
    }

    // Initialize the minimum total value found so far to the maximum possible long long value.
    ll min_total_val = numeric_limits<ll>::max(); 

    // Iterate through each element A[j] as a potential location x for the minimum value of f_B(x).
    // The minimum of f_B(x) on [A_1, A_N] must occur at some A_j.
    for (int j = 0; j < N; ++j) { 
        
        ll D_K = 0; // Stores the K-th smallest distance |A[j] - A[i]|. Initialized to 0.
        
        // If K > 0, we need to find the K-th smallest distance.
        // If K = 0, the set of indices with B_i=1 is empty, M_j=0.
        if (K > 0) {
            // Binary search for the K-th smallest distance D_K.
            ll low = 0, high; 
            // Determine a safe upper bound for distances. Max distance from A[j] is to A[0] or A[N-1].
            if (N == 1) high = 0;
            else high = max(abs(A[j] - A[0]), abs(A[j] - A[N-1]));
            // Add a small buffer to high just in case, though max difference should cover it. high = high + 1? Not really necessary as distances are non-negative.
            
            // Perform binary search on the distance value.
            while (low <= high) {
                ll mid = low + (high - low) / 2; // Calculate midpoint avoiding potential overflow.
                
                // Count elements A[i] such that |A[j] - A[i]| <= mid.
                // This condition is equivalent to A[i] being in the range [A[j] - mid, A[j] + mid].
                // lower_bound finds the first element >= A[j] - mid.
                auto it_p = lower_bound(A.begin(), A.end(), A[j] - mid);
                // upper_bound finds the first element > A[j] + mid.
                auto it_q = upper_bound(A.begin(), A.end(), A[j] + mid);
                // The number of elements in the range [it_p, it_q) is the count.
                ll count = distance(it_p, it_q);

                // If count is at least K, it means mid might be D_K or larger than D_K.
                // We try smaller distances to find the minimum D_K satisfying the condition.
                if (count >= K) {
                    D_K = mid; // Update D_K potential value.
                    high = mid - 1; // Search in the lower half [low, mid-1].
                } else {
                    // If count < K, the distance mid is too small. Need larger distance.
                    low = mid + 1; // Search in the upper half [mid+1, high].
                }
            }
        }

        // Calculate M_j = sum of the K smallest distances |A[j] - A[i]|.
        ll C_lessK = 0; // Count of points with distance strictly less than D_K.
        ll M_lessK = 0; // Sum of distances for points with distance strictly less than D_K.
        
        // Calculate C_lessK and M_lessK only if K > 0 and D_K > 0.
        // If D_K = 0, distances < D_K is impossible, so C_lessK = 0 and M_lessK = 0.
        // If K = 0, M_j = 0 anyway.
        if (K > 0 && D_K > 0) {
             ll D_less = D_K - 1; // The maximum distance strictly less than D_K.
             // Find points A[i] such that |A[j] - A[i]| <= D_less.
             auto it_p_less = lower_bound(A.begin(), A.end(), A[j] - D_less);
             auto it_q_less = upper_bound(A.begin(), A.end(), A[j] + D_less);
             C_lessK = distance(it_p_less, it_q_less);
             
             // Calculate the sum of distances for these points.
             int p_idx = distance(A.begin(), it_p_less); // 0-based index of first element.
             int q_idx = distance(A.begin(), it_q_less) - 1; // 0-based index of last element.
             M_lessK = sum_abs_diff(j, p_idx, q_idx, A);
        }
        
        // The total sum M_j consists of sum for distances < D_K, plus (K - C_lessK) items each contributing D_K.
        ll M_j = M_lessK + (K - C_lessK) * D_K;
        
        // Calculate T_j = total sum of distances |A[j] - A[i]| for all i from 0 to N-1.
        ll T_j = sum_abs_diff(j, 0, N - 1, A);

        // Calculate V_j = 2 * M_j - T_j. This is the minimum value of f_B(x) at x=A[j] for the optimal B for this j.
        ll current_Vj = 2 * M_j - T_j;

        // Update the overall minimum value found across all j.
        min_total_val = min(min_total_val, current_Vj);
    }

    // Output the final minimum value found.
    cout << min_total_val << endl;

    return 0;
}
0