結果

問題 No.924 紲星
ユーザー qwewe
提出日時 2025-05-14 12:59:49
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 9,965 bytes
コンパイル時間 145 ms
コンパイル使用メモリ 82,532 KB
実行使用メモリ 463,952 KB
最終ジャッジ日時 2025-05-14 13:01:45
合計ジャッジ時間 9,295 ms
ジャッジサーバーID (参考情報) judge4 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 11 TLE * 5
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

# Raising the recursion limit is not needed: any recursion in this solution is
# at most about log2(number of distinct values) + 1 levels deep, far below the
# default limit. Uncomment the line below only if a judge environment demands it.
# sys.setrecursionlimit(200010) 

def _build_persistent_tree(ranks, values, num_ranks):
    """Build a persistent count/sum segment tree with one version per prefix.

    The tree is indexed by value rank in ``[0, num_ranks - 1]``. Nodes live in
    four flat parallel lists instead of per-node Python objects, which is much
    cheaper to allocate and traverse. Node 0 is a shared empty sentinel
    (count 0, sum 0, both children pointing back at node 0), so traversals
    never need ``None`` checks.

    Args:
        ranks: ``ranks[i]`` is the 0-based rank of ``values[i]`` among the
            distinct values.
        values: the original array values, same length as ``ranks``.
        num_ranks: number of distinct values (size of the rank domain).

    Returns:
        ``(roots, left, right, count, total)``: ``roots[i]`` is the root node
        of the version containing ``values[0:i]``; ``left``/``right`` hold each
        node's children, ``count`` its element count and ``total`` its value sum.
    """
    n = len(ranks)
    # Longest root-to-leaf path over [0, num_ranks - 1] with midpoint splits
    # (the left half is never smaller, so the leftmost path is the longest).
    path_len = (num_ranks - 1).bit_length() + 1 if num_ranks else 0
    size = 1 + n * path_len
    left = [0] * size
    right = [0] * size
    count = [0] * size
    total = [0] * size
    roots = [0] * (n + 1)

    free = 1  # next unused node index; node 0 is the sentinel
    for i in range(n):
        rank = ranks[i]
        value = values[i]
        prev = roots[i]  # node in the previous version at the same position
        node = free
        roots[i + 1] = node
        lo, hi = 0, num_ranks - 1
        while True:
            # Path nodes are allocated consecutively: node, node + 1, ...
            count[node] = count[prev] + 1
            total[node] = total[prev] + value
            if lo == hi:
                break
            mid = (lo + hi) // 2
            child = node + 1
            if rank <= mid:
                left[node] = child
                right[node] = right[prev]  # share the unchanged subtree
                prev = left[prev]
                hi = mid
            else:
                left[node] = left[prev]  # share the unchanged subtree
                right[node] = child
                prev = right[prev]
                lo = mid + 1
            node = child
        free = node + 1
    return roots, left, right, count, total


def solve():
    """Read all queries from stdin and print one answer per line.

    Input: ``N Q``, then ``A_1 .. A_N``, then ``Q`` pairs ``L R`` (1-based,
    inclusive). For each query this prints ``min_x sum_{i=L..R} |A_i - x|``.
    The minimum is attained at the lower median ``M`` of ``A[L..R]`` and
    equals ``M * (2 * below_cnt - m) + range_sum - 2 * below_sum``, where
    ``below_cnt``/``below_sum`` are the count and sum of elements strictly
    less than ``M``. Elements equal to ``M`` contribute 0 either way.

    A persistent segment tree over value ranks is built once. Each query then
    does a single iterative root-to-leaf walk that both locates the median and
    accumulates ``below_cnt``/``below_sum``. That keeps per-query work at
    O(log U) with no recursion, no per-node objects and no second tree walk.
    """
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    q = int(data[1])
    a = [int(tok) for tok in data[2:2 + n]]

    # Coordinate compression: vals[r] is the r-th smallest distinct value.
    vals = sorted(set(a))
    num_ranks = len(vals)
    rank_of = {v: i for i, v in enumerate(vals)}
    roots, left, right, count, total = _build_persistent_tree(
        [rank_of[v] for v in a], a, num_ranks
    )

    # prefix[i] = A_1 + ... + A_i, for O(1) range sums.
    prefix = [0] * (n + 1)
    for i, v in enumerate(a):
        prefix[i + 1] = prefix[i] + v

    answers = []
    pos = 2 + n
    for _ in range(q):
        ql = int(data[pos])
        qr = int(data[pos + 1])
        pos += 2
        m = qr - ql + 1
        k = (m + 1) // 2  # the lower median is the k-th smallest element

        # Walk versions qr and ql-1 in lockstep; their difference is A[ql..qr].
        node_r = roots[qr]
        node_l = roots[ql - 1]
        lo, hi = 0, num_ranks - 1
        below_cnt = 0
        below_sum = 0
        while lo < hi:
            mid = (lo + hi) // 2
            left_r = left[node_r]
            left_l = left[node_l]
            in_left = count[left_r] - count[left_l]
            if k <= in_left:
                node_r, node_l, hi = left_r, left_l, mid
            else:
                # Every element in the left half is strictly below the median.
                k -= in_left
                below_cnt += in_left
                below_sum += total[left_r] - total[left_l]
                node_r, node_l, lo = right[node_r], right[node_l], mid + 1

        median = vals[lo] if num_ranks else 0
        range_sum = prefix[qr] - prefix[ql - 1]
        answers.append(median * (2 * below_cnt - m) + range_sum - 2 * below_sum)

    sys.stdout.write("\n".join(map(str, answers)) + "\n")

# Run only when executed as a script, so importing this module (e.g. from a
# test) does not consume stdin. The stray `0` scraped in after the call was a
# no-op expression and is dropped.
if __name__ == "__main__":
    solve()