結果

問題 No.924 紲星
ユーザー gew1fw
提出日時 2025-06-12 18:52:19
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 4,588 bytes
コンパイル時間 134 ms
コンパイル使用メモリ 82,388 KB
実行使用メモリ 78,088 KB
最終ジャッジ日時 2025-06-12 18:52:33
合計ジャッジ時間 7,542 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 5 TLE * 1 -- * 10
権限があれば一括ダウンロードができます

ソースコード

diff #

import bisect
import sys
import math

class SegmentTreeNode:
    def __init__(self, l, r):
        self.l = l
        self.r = r
        self.left_child = None
        self.right_child = None
        self.sorted_list = []
        self.prefix_sum = []

def merge(left, right):
    merged = []
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            merged.append(left[i])
            i += 1
        else:
            merged.append(right[j])
            j += 1
    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged

def build_segment_tree(l, r, A):
    node = SegmentTreeNode(l, r)
    if l == r:
        node.sorted_list = [A[l-1]]
        node.prefix_sum = [0]
        for num in node.sorted_list:
            node.prefix_sum.append(node.prefix_sum[-1] + num)
    else:
        mid = (l + r) // 2
        node.left_child = build_segment_tree(l, mid, A)
        node.right_child = build_segment_tree(mid+1, r, A)
        node.sorted_list = merge(node.left_child.sorted_list, node.right_child.sorted_list)
        node.prefix_sum = [0]
        for num in node.sorted_list:
            node.prefix_sum.append(node.prefix_sum[-1] + num)
    return node

def count_less_or_equal(x, L, R, root):
    stack = [root]
    count = 0
    while stack:
        node = stack.pop()
        if node.r < L or node.l > R:
            continue
        if L <= node.l and node.r <= R:
            cnt = bisect.bisect_right(node.sorted_list, x)
            count += cnt
        else:
            stack.append(node.right_child)
            stack.append(node.left_child)
    return count

def sum_less_or_equal(x, L, R, root):
    stack = [root]
    total = 0
    while stack:
        node = stack.pop()
        if node.r < L or node.l > R:
            continue
        if L <= node.l and node.r <= R:
            idx = bisect.bisect_right(node.sorted_list, x)
            total += node.prefix_sum[idx]
        else:
            stack.append(node.right_child)
            stack.append(node.left_child)
    return total

def build_sparse_table(data, func, log_table):
    n = len(data)
    k_max = log_table[n] + 1
    st = []
    st.append(data.copy())
    for k in range(1, k_max):
        curr = []
        for i in range(n - (1 << k) + 1):
            curr_val = func(st[k-1][i], st[k-1][i + (1 << (k-1))])
            curr.append(curr_val)
        st.append(curr)
    return st

def query_sparse_table(st, log_table, l, r, func):
    length = r - l
    if length == 0:
        return None
    k = log_table[length]
    val1 = st[k][l]
    val2 = st[k][r - (1 << k)]
    return func(val1, val2)

def main():
    input = sys.stdin.read().split()
    ptr = 0
    N, Q = int(input[ptr]), int(input[ptr+1])
    ptr += 2
    A = list(map(int, input[ptr:ptr+N]))
    ptr += N
    queries = []
    for _ in range(Q):
        L = int(input[ptr])
        R = int(input[ptr+1])
        queries.append((L, R))
        ptr += 2

    # Build prefix sum array
    prefix_sum = [0] * (N + 1)
    for i in range(N):
        prefix_sum[i+1] = prefix_sum[i] + A[i]

    # Build sparse tables for min and max
    data = A
    n = len(data)
    log_table = [0] * (n + 1)
    for i in range(2, n + 1):
        log_table[i] = log_table[i // 2] + 1
    st_min = build_sparse_table(data, min, log_table)
    st_max = build_sparse_table(data, max, log_table)

    # Build segment tree for sorted lists and prefix sums
    root = build_segment_tree(1, N, A)

    for L, R in queries:
        l = L - 1  # 0-based
        r = R
        len_sub = R - L + 1
        k = (len_sub + 1) // 2

        # Find min and max in [L, R]
        min_val = query_sparse_table(st_min, log_table, l, r, min)
        max_val = query_sparse_table(st_max, log_table, l, r, max)

        # Binary search for the k-th smallest element
        low = min_val
        high = max_val
        answer_m = high
        while low <= high:
            mid = (low + high) // 2
            cnt = count_less_or_equal(mid, L, R, root)
            if cnt >= k:
                answer_m = mid
                high = mid - 1
            else:
                low = mid + 1
        m = answer_m

        # Compute sum_less and cnt_less
        sum_less = sum_less_or_equal(m, L, R, root)
        cnt_less = count_less_or_equal(m, L, R, root)
        sum_total = prefix_sum[R] - prefix_sum[L-1]
        sum_greater = sum_total - sum_less
        cnt_greater = len_sub - cnt_less

        ans = m * cnt_less - sum_less + sum_greater - m * cnt_greater
        print(ans)

if __name__ == '__main__':
    main()
0