結果

問題 No.924 紲星
ユーザー gew1fw
提出日時 2025-06-12 13:50:38
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 2,757 ms / 4,000 ms
コード長 3,688 bytes
コンパイル時間 445 ms
コンパイル使用メモリ 82,776 KB
実行使用メモリ 391,660 KB
最終ジャッジ日時 2025-06-12 13:52:04
合計ジャッジ時間 22,964 ms
ジャッジサーバーID (参考情報) judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

class _WaveletNode:
    """A node of a wavelet tree built over a contiguous range of distinct values.

    Let ``seq`` be the input elements whose values fall in this node's value
    range, kept in their original order. Then:

    * ``left_count[i]`` is how many of ``seq[:i]`` are routed to ``left``;
    * ``left_sum[i]`` is the sum of those routed elements.

    A leaf holds exactly one distinct value (``value`` is not ``None``) and
    treats every element as routed "left", so its arrays are plain prefix
    counts / prefix sums of ``seq``.
    """

    # Many nodes are created; slots avoid a per-instance __dict__.
    __slots__ = ("value", "left", "right", "left_count", "left_sum")

    def __init__(self, value=None):
        self.value = value  # the leaf's value, or None for an internal node
        self.left = None
        self.right = None
        self.left_count = [0]
        self.left_sum = [0]


def _build_wavelet(values, distinct, lo, hi):
    """Build the wavelet subtree for ``values`` over ``distinct[lo..hi]``.

    ``distinct`` is the sorted list of distinct input values and every
    element of ``values`` must lie in ``distinct[lo..hi]`` (inclusive,
    ``lo <= hi``). Recursion depth is about ``log2(len(distinct))``.
    """
    if lo == hi:
        leaf = _WaveletNode(distinct[lo])
        # Every element stays here, so prefix counts are simply 0..len.
        leaf.left_count = list(range(len(values) + 1))
        sums = leaf.left_sum
        running = 0
        for v in values:
            running += v
            sums.append(running)
        return leaf

    node = _WaveletNode()
    mid = (lo + hi) // 2
    pivot = distinct[mid]  # values <= pivot go left, the rest go right
    left_vals = []
    right_vals = []
    counts = node.left_count
    sums = node.left_sum
    cnt = 0
    running = 0
    for v in values:
        if v <= pivot:
            left_vals.append(v)
            cnt += 1
            running += v
        else:
            right_vals.append(v)
        counts.append(cnt)
        sums.append(running)
    node.left = _build_wavelet(left_vals, distinct, lo, mid)
    node.right = _build_wavelet(right_vals, distinct, mid + 1, hi)
    return node


def _kth_and_low_sum(root, left, right, k):
    """Return ``(m, s)`` for positions ``left..right`` (0-based, inclusive).

    ``m`` is the k-th smallest value in that range (``k`` is 1-based) and
    ``s`` is the sum of its ``k`` smallest values. Both come from a single
    root-to-leaf walk. Requires ``1 <= k <= right - left + 1``.
    """
    node = root
    acc = 0
    while node.value is None:
        counts = node.left_count
        before = counts[left]  # routed left strictly before `left`
        through = counts[right + 1]  # routed left up to and including `right`
        in_left = through - before
        if k <= in_left:
            left, right = before, through - 1
            node = node.left
        else:
            # Every left-routed element of the range is among the k smallest.
            acc += node.left_sum[right + 1] - node.left_sum[left]
            k -= in_left
            # Positions in the right child = positions minus left-routed ones.
            left, right = left - before, right - through
            node = node.right
    return node.value, acc + node.value * k


def main():
    """Read the problem input from stdin and print one answer per query.

    Input: ``N Q``, then ``A_1 .. A_N``, then ``Q`` pairs ``L R`` (1-based,
    inclusive). For each query this prints ``min_x sum_{i=L..R} |A_i - x|``.
    The minimum is attained at the median of ``A[L..R]``.
    """
    data = sys.stdin.read().split()
    n, q = int(data[0]), int(data[1])
    a = list(map(int, data[2:2 + n]))

    prefix = [0] * (n + 1)
    for i, v in enumerate(a):
        prefix[i + 1] = prefix[i] + v

    distinct = sorted(set(a))
    root = None
    if distinct:  # with n == 0 there is nothing to index (and no valid query)
        root = _build_wavelet(a, distinct, 0, len(distinct) - 1)

    out = []
    pos = 2 + n
    for _ in range(q):
        left = int(data[pos]) - 1
        right = int(data[pos + 1]) - 1
        pos += 2
        length = right - left + 1
        k = (length + 1) // 2  # 1-based rank of the (lower) median
        m, low = _kth_and_low_sum(root, left, right, k)
        # Sum of the (length - k) largest values in the range.
        high = prefix[right + 1] - prefix[left] - low
        # Distance of the low half up to m plus the high half down to m.
        out.append((m * k - low) + (high - m * (length - k)))

    # One buffered write instead of a print() per query.
    if out:
        sys.stdout.write("\n".join(map(str, out)) + "\n")

# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
0