結果

問題 No.1240 Or Sum of Xor Pair
ユーザー qwewe
提出日時 2025-05-14 13:10:28
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 11,218 bytes
コンパイル時間 765 ms
コンパイル使用メモリ 76,780 KB
実行使用メモリ 47,744 KB
最終ジャッジ日時 2025-05-14 13:11:49
合計ジャッジ時間 9,887 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 29 WA * 1
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <iostream>
#include <vector>
#include <numeric>

// Number of bits examined by the trie, from bit B-1 (most significant) down to bit 0.
// Constraints: every A_i is < 2^18, but X may be as large as 2^18 itself.
// X = 2^18 needs bit 18. With only 18 bits (0..17) the query never looks at that bit,
// treats X as 0, and accepts no pair, although every pair satisfies A_i ^ A_j < 2^18.
// With 19 bits, bit 18 of X is visible. Bit 18 of every A_i ^ A_j is 0 (< 1), so all
// pairs are then correctly accepted.
const int B = 19;

// Trie Node structure
struct TrieNode {
    TrieNode* children[2]; // Pointers to child nodes for bit 0 and 1
    int count;             // Number of elements whose path passes through this node and ends in its subtree
    // bit_sum[p] stores the count of elements in this node's subtree that have the p-th bit set.
    // Using int (typically 32-bit) is sufficient as max count N=2e5 fits.
    int bit_sum[B];        

    // Constructor initializes node fields
    TrieNode() : count(0) {
        children[0] = children[1] = nullptr;
        // Initialize bit_sum array elements to 0
        for(int i=0; i<B; ++i) bit_sum[i] = 0;
    }

    // Optional Destructor to free memory - generally omitted in competitive programming for speed
    // unless memory limits are very strict or multiple test cases run in one process.
    // ~TrieNode() {
    //     delete children[0];
    //     delete children[1];
    // }
};

// Adds `val` to the trie rooted at `root`.
// Walks from the root along bits B-1 .. 0 of `val`, creating missing nodes. Every node
// on the path, root and final leaf included (B+1 nodes), has its count and per-bit
// tallies updated to account for `val`.
void insert(TrieNode* root, int val) {
    TrieNode* node = root;
    for (int depth = B - 1; ; --depth) {
        node->count += 1;
        for (int bit = 0; bit < B; ++bit) {
            node->bit_sum[bit] += (val >> bit) & 1;
        }
        if (depth < 0) {
            break;  // the leaf for `val` has just been updated
        }
        const int dir = (val >> depth) & 1;
        if (node->children[dir] == nullptr) {
            node->children[dir] = new TrieNode();
        }
        node = node->children[dir];
    }
}

// Returns the sum of (val | A_j) over every value A_j stored in the subtree of `node`.
// Uses only the node's count and per-bit tallies, i.e. B steps, and never walks the subtree.
// Returns 0 for a missing or empty subtree.
long long calculate_or_sum(int val, TrieNode* node) {
    if (node == nullptr || node->count == 0) return 0;

    long long total = 0;
    for (int p = 0; p < B; ++p) {
        // When val has bit p set, bit p of (val | A_j) is 1 for every A_j.
        // Otherwise it is 1 exactly for the A_j that have bit p set.
        const long long ones = ((val >> p) & 1) ? static_cast<long long>(node->count)
                                                : static_cast<long long>(node->bit_sum[p]);
        total += ones << p;
    }
    return total;
}

// Answer of one trie query for a value A_i:
// how many stored A_j matched, and the sum of (A_i | A_j) over those matches.
struct QueryResult {
    long long count = 0;         // number of matching A_j
    long long total_or_sum = 0;  // sum of (A_i | A_j) over the matches
};

// Recursively finds every A_j stored below `node` with (val ^ A_j) < X.
// Returns how many there are and the sum of (val | A_j) over them.
//
// node : current trie node; its subtree holds the candidates still being considered
// val  : the value A_i being matched against the trie
// k    : bit currently being decided (B-1 on the first call, decreasing to -1)
// X    : exclusive upper bound on val ^ A_j
//
// Invariant: for every A_j below `node`, bits B-1..k+1 of (val ^ A_j) equal those of X.
QueryResult query_sum(TrieNode* node, int val, int k, int X) {
    if (!node || node->count == 0) {
        return {0, 0};
    }
    // The trie only distinguishes bits 0..B-1. Assuming val and the stored values fit in
    // B bits, every XOR is < 2^B. If X has any bit at position >= B, the whole subtree
    // therefore qualifies. Without this check such an X would be compared only on its
    // low B bits and almost every pair would be wrongly rejected.
    if ((X >> B) != 0) {
        return {node->count, calculate_or_sum(val, node)};
    }
    // All B bits are decided and the XOR equals X exactly, which is not strictly smaller.
    if (k < 0) {
        return {0, 0};
    }

    const int val_k = (val >> k) & 1;
    const int X_k = (X >> k) & 1;
    // Following children[val_k] makes XOR bit k equal 0; children[val_k ^ 1] makes it 1.
    TrieNode* const same = node->children[val_k];
    TrieNode* const flip = node->children[val_k ^ 1];

    QueryResult result = {0, 0};
    if (X_k == 1) {
        // XOR bit 0 < X bit 1: the whole `same` subtree is already strictly below X.
        if (same) {
            result.count += same->count;
            result.total_or_sum += calculate_or_sum(val, same);
        }
        // XOR bit 1 == X bit 1: still tied with X, so the lower bits decide.
        const QueryResult tied = query_sum(flip, val, k - 1, X);
        result.count += tied.count;
        result.total_or_sum += tied.total_or_sum;
    } else {
        // X bit 0: the XOR bit must be 0 as well to stay tied.
        // The `flip` subtree would exceed X and is skipped.
        const QueryResult tied = query_sum(same, val, k - 1, X);
        result.count += tied.count;
        result.total_or_sum += tied.total_or_sum;
    }
    return result;
}

// Frees the subtree rooted at `node`, children first (post-order).
// A null `node` is a no-op.
void deleteTrie(TrieNode* node) {
    if (node == nullptr) {
        return;
    }
    for (TrieNode* child : node->children) {
        deleteTrie(child);
    }
    delete node;
}

int main() {
    // Use fast I/O operations
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(NULL);

    int N; // Number of integers
    int X; // XOR limit
    std::cin >> N >> X;

    std::vector<int> A(N); // Vector to store the input integers
    for (int i = 0; i < N; ++i) {
        std::cin >> A[i];
    }

    TrieNode* root = new TrieNode(); // Create the root node of the trie
    long long total_sum = 0; // Initialize total sum accumulator (use long long for potentially large sums)

    // Process each element A[i] of the input array
    for (int i = 0; i < N; ++i) {
        // Query the trie with A[i]. This finds all A[j] already in the trie (meaning j < i)
        // such that A[i] ^ A[j] < X. It also computes the sum of (A[i] | A[j]) for these pairs.
        QueryResult result = query_sum(root, A[i], B - 1, X);
        
        // Add the sum obtained from the query to the total sum
        total_sum += result.total_or_sum;
        
        // Insert A[i] into the trie. This makes it available for querying against subsequent elements A[k] where k > i.
        insert(root, A[i]);
    }

    // Output the final computed total sum
    std::cout << total_sum << std::endl;

    // Optional: Deallocate trie memory. Not strictly necessary for typical competitive programming contest submissions
    // as OS reclaims memory on program exit, but good practice in general software development.
    // deleteTrie(root); 
    
    return 0;
}
0