結果

問題 No.956 Number of Unbalanced
ユーザー qwewe
提出日時 2025-05-14 13:09:15
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 447 ms / 2,000 ms
コード長 11,160 bytes
コンパイル時間 1,379 ms
コンパイル使用メモリ 111,772 KB
実行使用メモリ 13,952 KB
最終ジャッジ日時 2025-05-14 13:10:20
合計ジャッジ時間 5,524 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 6
other AC * 24
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <iostream>
#include <vector>
#include <map>
#include <cmath>
#include <numeric>
#include <algorithm> // For std::max, std::min

// Fenwick tree (Binary Indexed Tree) implementation
// Template allows usage with different numeric types, like long long for counts
template <typename T>
struct FenwickTree {
    int size;             // number of logical slots, addressed 0 .. size-1
    std::vector<T> tree;  // 1-based internal storage; tree[0] is never used

    // Builds a tree over `sz` slots, all starting at zero.
    FenwickTree(int sz) : size(sz), tree(sz + 1, 0) {}

    // Adds `delta` to slot `idx` (0-based). Indices outside [0, size) are
    // silently ignored rather than reported.
    void update(int idx, T delta) {
        if (idx >= 0 && idx < size) {
            int node = idx + 1;  // switch to the 1-based internal numbering
            while (node <= size) {
                tree[node] += delta;
                node += node & -node;  // climb to the next node covering this slot
            }
        }
    }

    // Returns the sum of slots 0..idx inclusive.
    // A negative idx yields 0; an idx past the end is clamped to size-1.
    T query(int idx) {
        T total = 0;
        if (idx >= 0) {
            int node = std::min(idx, size - 1) + 1;
            while (node > 0) {
                total += tree[node];
                node &= node - 1;  // drop the lowest set bit (== node -= node & -node)
            }
        }
        return total;
    }

    // Returns the sum of the first `idx` slots, i.e. slots 0..idx-1.
    // A non-positive idx denotes an empty range and yields 0.
    T query_prefix(int idx) {
        return idx > 0 ? query(idx - 1) : T(0);
    }
};


// Counts subarrays of A in which v occurs strictly more than half of the time.
//
// Let P[0] = 0 and P[t] = P[t-1] + (A[t-1] == v ? +1 : -1). The subarray
// A[s..t-1] has v as a strict majority iff P[s] < P[t], so the answer is the
// number of index pairs s < t with P[s] < P[t].
// Consecutive prefix sums differ by exactly 1, so the number of earlier
// prefixes strictly below the current one can be updated in O(1) per step:
// stepping up from h adds the prefixes equal to h, and stepping down to h-1
// removes the prefixes equal to h-1. Runs in O(|A|) time and memory.
static long long count_majority_dense(const std::vector<int>& A, int v) {
    const int n = static_cast<int>(A.size());
    // seen[h + n] = number of prefixes visited so far whose value is h, h in [-n, n].
    std::vector<int> seen(2 * n + 1, 0);
    int level = n;        // current prefix value, shifted by n so it is a valid index
    long long below = 0;  // visited prefixes strictly below `level`
    long long result = 0;
    seen[level] = 1;      // P[0] = 0
    for (const int x : A) {
        if (x == v) {
            below += seen[level];  // prefixes equal to the old level are now below
            ++level;
        } else {
            --level;
            below -= seen[level];  // prefixes equal to the new level are no longer below
        }
        result += below;  // every earlier prefix below P[t] starts a majority subarray
        ++seen[level];
    }
    return result;
}

// Counts subarrays A[i..j] (0-based, inclusive) of an array of length n in
// which the value occurring at the strictly increasing positions `pos` is a
// strict majority.
//
// Every such subarray contains a contiguous run pos[k-1..m-1] of occurrences
// (c = m - k + 1 of them). For each run the valid (i, j) pairs satisfy
//   pos[k-2] < i <= pos[k-1],  pos[m-1] <= j < pos[m],  2c > j - i + 1,
// and their number is summed in closed form. Cost is O(|pos|^2), which is why
// this path is only used for values that occur rarely.
static long long count_majority_sparse(const std::vector<int>& pos, int n) {
    const int cnt = static_cast<int>(pos.size());
    // Sentinel-padded positions: ext[0] = -1, ext[1..cnt] = pos, ext[cnt+1] = n.
    std::vector<int> ext;
    ext.reserve(cnt + 2);
    ext.push_back(-1);
    ext.insert(ext.end(), pos.begin(), pos.end());
    ext.push_back(n);

    long long result = 0;
    for (int m = 1; m <= cnt; ++m) {
        // j must cover occurrence m but not occurrence m+1; never empty because
        // positions are strictly increasing and the right sentinel is n.
        const int j_lo = ext[m];
        const int j_hi = ext[m + 1] - 1;
        for (int k = m; k >= 1; --k) {
            const int c = m - k + 1;        // occurrences inside the subarray
            const int i_lo = ext[k - 1] + 1;  // i must exclude occurrence k-1 ...
            const int i_hi = ext[k];          // ... and include occurrence k

            // Majority: 2c > j - i + 1  <=>  i >= j + 2 - 2c, so for a fixed j
            // the valid i form [max(i_lo, j + 2 - 2c), i_hi].
            // While j <= pivot the lower bound is i_lo: a constant count.
            const int pivot = i_lo + 2 * c - 2;
            const int flat_hi = std::min(j_hi, pivot);
            if (j_lo <= flat_hi) {
                result += static_cast<long long>(i_hi - i_lo + 1) * (flat_hi - j_lo + 1);
            }

            // Past the pivot the count is top - j, positive only while j < top:
            // sum that arithmetic progression.
            const int top = i_hi + 2 * c - 1;
            const int ramp_lo = std::max(j_lo, pivot + 1);
            const int ramp_hi = std::min(j_hi, top - 1);
            if (ramp_lo <= ramp_hi) {
                const long long terms = ramp_hi - ramp_lo + 1;
                result += terms * ((top - ramp_lo) + (top - ramp_hi)) / 2;
            }
        }
    }
    return result;
}

// Counts subarrays of A that have a strict-majority element ("unbalanced").
// A subarray has at most one strict majority, so the answer is the sum over
// distinct values of the subarrays each value dominates. Square-root
// decomposition picks the cheaper method per value: the O(N) prefix walk for
// values occurring more than sqrt(N) times (at most sqrt(N) such values), and
// the O(count^2) run enumeration otherwise (at most N * sqrt(N) in total).
static long long count_unbalanced(const std::vector<int>& A) {
    const int n = static_cast<int>(A.size());

    std::map<int, std::vector<int>> positions;  // value -> sorted 0-based indices
    for (int i = 0; i < n; ++i) {
        positions[A[i]].push_back(i);
    }

    // Threshold separating frequent from infrequent values; at least 1.
    const int threshold = std::max(1, static_cast<int>(std::sqrt(static_cast<double>(n))));

    long long total = 0;
    for (const auto& [value, pos] : positions) {
        const int occurrences = static_cast<int>(pos.size());
        total += occurrences > threshold ? count_majority_dense(A, value)
                                         : count_majority_sparse(pos, n);
    }
    return total;
}

// Reads N followed by N integers from stdin and prints the number of
// subarrays that have a strict-majority element.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int N = 0;
    std::cin >> N;
    // Guard against a negative length, which would make the vector constructor throw.
    std::vector<int> A(std::max(N, 0));
    for (int& x : A) {
        std::cin >> x;
    }

    // '\n' rather than std::endl: the stream is flushed at exit anyway.
    std::cout << count_unbalanced(A) << '\n';
    return 0;
}
0