結果

問題 No.924 紲星
ユーザー lam6er
提出日時 2025-04-16 01:16:37
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 3,734 bytes
コンパイル時間 272 ms
コンパイル使用メモリ 81,860 KB
実行使用メモリ 316,176 KB
最終ジャッジ日時 2025-04-16 01:16:57
合計ジャッジ時間 7,071 ms
ジャッジサーバーID (参考情報) judge2 / judge4
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 5 TLE * 1 -- * 10
権限があれば一括ダウンロードができます

ソースコード

diff #

import bisect

def main():
    """Answer every range query read from standard input and print the results.

    Input (whitespace-separated tokens): ``N Q``, then ``A_1 .. A_N``, then
    ``Q`` pairs ``L R`` (1-based, inclusive).  For each query this prints
    ``sum(|A_i - m| for i in L..R)`` where ``m`` is the lower median (the
    ceil(len/2)-th smallest value) of ``A[L-1:R]``; a median minimises the
    sum of absolute deviations.  One answer per line.

    The previous version binary-searched the median over every distinct value
    and, for each probe, walked a merge-sort tree doing a ``bisect`` per node
    (O(log^3 N) per query plus O(N log N) stored lists), and was judged TLE.
    Here a wavelet matrix over the coordinate-compressed values finds the
    median *and* the count/sum of all strictly smaller elements in a single
    top-down pass: O(N log V) to build, O(log V) per query, where V is the
    number of distinct values.
    """
    import sys

    data = sys.stdin.read().split()
    n = int(data[0])
    q = int(data[1])
    a = list(map(int, data[2:2 + n]))

    # prefix[i] = A[0] + ... + A[i-1]; gives the total of any range in O(1).
    prefix = [0] * (n + 1)
    for i, v in enumerate(a):
        prefix[i + 1] = prefix[i] + v

    # Coordinate compression: each element becomes its index in `values`.
    values = sorted(set(a))
    index_of = {v: i for i, v in enumerate(values)}
    codes = [index_of[v] for v in a]
    # Number of bit levels needed to distinguish all codes (at least one).
    height = max(1, (len(values) - 1).bit_length())

    # Wavelet matrix, stored from the most significant bit level downwards.
    #   zero_rank[j][i]: how many of the first i positions have a 0 bit at level j
    #   zero_sum[j][i]:  sum of the original values at those 0-bit positions
    #   zero_count[j]:   total 0 bits at level j, i.e. where the stably
    #                    partitioned 1-bit elements start in the next level
    zero_rank = []
    zero_sum = []
    zero_count = []
    for level in range(height - 1, -1, -1):
        zr = [0] * (n + 1)
        zs = [0] * (n + 1)
        zeros = []
        ones = []
        cnt = 0
        acc = 0
        for i, c in enumerate(codes):
            if (c >> level) & 1:
                ones.append(c)
            else:
                zeros.append(c)
                cnt += 1
                acc += values[c]
            zr[i + 1] = cnt
            zs[i + 1] = acc
        zero_rank.append(zr)
        zero_sum.append(zs)
        zero_count.append(cnt)
        # Stable partition: 0-bit elements first, then 1-bit elements.
        codes = zeros + ones

    levels = list(zip(range(height - 1, -1, -1), zero_rank, zero_sum, zero_count))

    out = []
    pos = 2 + n
    for _ in range(q):
        left = int(data[pos]) - 1   # 0-based, inclusive
        right = int(data[pos + 1])  # 0-based, exclusive
        pos += 2
        length = right - left
        k = (length - 1) // 2       # 0-based rank of the lower median

        code = 0        # compressed value of the median, built bit by bit
        below_cnt = 0   # elements strictly smaller than the median
        below_sum = 0   # their sum
        lo, hi = left, right
        for level, zr, zs, zc in levels:
            lo0 = zr[lo]
            hi0 = zr[hi]
            zeros_here = hi0 - lo0
            if k < zeros_here:
                # The median has a 0 bit here: descend into the 0 half.
                lo, hi = lo0, hi0
            else:
                # Every 0-bit element in range is smaller than the median.
                k -= zeros_here
                below_cnt += zeros_here
                below_sum += zs[hi] - zs[lo]
                code |= 1 << level
                lo = zc + lo - lo0
                hi = zc + hi - hi0

        median = values[code]
        total = prefix[right] - prefix[left]
        # Elements equal to the median contribute 0 on either side, so the
        # split "strictly below" vs "the rest" is exact.
        ans = (median * below_cnt - below_sum) + (
            total - below_sum - median * (length - below_cnt)
        )
        out.append(str(ans))

    print('\n'.join(out))

# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()