結果

問題 No.924 紲星
ユーザー lam6er
提出日時 2025-03-31 17:39:29
言語 PyPy3 (7.3.15)
結果 MLE
実行時間 -
コード長 3,145 bytes
コンパイル時間 290 ms
コンパイル使用メモリ 82,084 KB
実行使用メモリ 512,756 KB
最終ジャッジ日時 2025-03-31 17:40:10
合計ジャッジ時間 7,570 ms
ジャッジサーバーID (参考情報) judge3 / judge1
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 5 MLE * 1 -- * 10
権限があれば一括ダウンロードができます

ソースコード

diff #

import bisect

class Node:
    """A node of a persistent count/sum segment tree.

    Attributes:
        left, right: child nodes (``None`` at leaves created by ``update``).
        cnt: number of stored elements whose rank falls in this node's range.
        sum: total of the original values of those elements.
    """

    __slots__ = ['left', 'right', 'cnt', 'sum']

    def __init__(self, left=None, right=None, cnt=0, sum=0):
        # ``sum`` shadows the builtin, but it is part of the keyword interface.
        self.left, self.right = left, right
        self.cnt, self.sum = cnt, sum

def build(l, r):
    """Return the root of an empty count/sum segment tree over ranks [l, r].

    Every node of an empty tree has ``cnt == 0`` and ``sum == 0``. So instead
    of allocating ``2*(r-l+1)-1`` identical nodes (wasted memory), this returns
    one sentinel node whose children point back to itself. Path-copying
    updates branch off it, and queries can descend into it at any depth and
    read zeros.

    ``l`` and ``r`` are kept for interface compatibility. Because nothing
    recurses on them any more, an empty range (``l > r``, e.g. no input
    values) no longer causes unbounded recursion.
    """
    empty = Node(None, None, 0, 0)
    # Self-loops make every "untouched" subtree the same zero node.
    empty.left = empty
    empty.right = empty
    return empty

def update(node, l, r, idx, value):
    """Return a new root: ``node`` with one more element at rank ``idx``.

    Only the nodes on the root-to-leaf path for ``idx`` are recreated. Every
    other subtree is shared with ``node``, which itself is left untouched.
    The leaf's ``cnt`` grows by 1 and its ``sum`` by ``value``.
    """
    trail = []  # (visited node, went_left) for every level above the leaf
    while l != r:
        mid = (l + r) // 2
        went_left = idx <= mid
        trail.append((node, went_left))
        if went_left:
            node, r = node.left, mid
        else:
            node, l = node.right, mid + 1

    fresh = Node(None, None, node.cnt + 1, node.sum + value)
    # Rebuild the copied path bottom-up, reusing the sibling subtrees.
    for parent, went_left in reversed(trail):
        if went_left:
            lhs, rhs = fresh, parent.right
        else:
            lhs, rhs = parent.left, fresh
        fresh = Node(lhs, rhs, lhs.cnt + rhs.cnt, lhs.sum + rhs.sum)
    return fresh

def find_kth(node_l, node_r, l, r, k):
    """Return the rank of the k-th smallest (1-based) element between two versions.

    The elements considered are those counted in ``node_r`` but not in
    ``node_l`` (counts are subtracted node by node). Descends from range
    ``[l, r]`` toward the leaf that holds the k-th element.
    """
    while l != r:
        mid = (l + r) // 2
        in_left = node_r.left.cnt - node_l.left.cnt
        if k <= in_left:
            node_l, node_r, r = node_l.left, node_r.left, mid
        else:
            # Skip the whole left half; k becomes relative to the right half.
            k -= in_left
            node_l, node_r, l = node_l.right, node_r.right, mid + 1
    return l

def get_sum_and_cnt(node_l, node_r, x, l, r):
    """Return ``(sum, count)`` of elements with rank <= ``x`` between two versions.

    The elements considered are those in ``node_r`` minus ``node_l`` over the
    range ``[l, r]``. Expects ``x >= l``.
    """
    if x >= r:
        # Whole range lies at or below x.
        return (node_r.sum - node_l.sum, node_r.cnt - node_l.cnt)
    half = (l + r) // 2
    acc_sum, acc_cnt = get_sum_and_cnt(node_l.left, node_r.left, x, l, half)
    if x <= half:
        return (acc_sum, acc_cnt)
    more_sum, more_cnt = get_sum_and_cnt(node_l.right, node_r.right, x, half + 1, r)
    return (acc_sum + more_sum, acc_cnt + more_cnt)

def _build_versions(ranks, values, size):
    """Build all prefix versions of a persistent count/sum segment tree in flat lists.

    Version ``i`` (root ``roots[i]``) holds the first ``i`` elements, indexed
    by compressed rank in ``[0, size - 1]``. Nodes live in four parallel int
    lists rather than one Python object per node. The object-per-node tree
    needs ~N*log(M) instances and exceeded the memory limit. Node 0 is an
    empty sentinel whose children are itself, so untouched subtrees cost
    nothing.

    Returns ``(roots, lc, rc, cnt, sm)``: the version roots, left/right child
    indices, element counts and value sums per node.
    """
    n = len(ranks)
    # One new node per level on each insertion path; depth <= size.bit_length().
    capacity = 1 + n * (size.bit_length() + 1)
    lc = [0] * capacity
    rc = [0] * capacity
    cnt = [0] * capacity
    sm = [0] * capacity
    roots = [0] * (n + 1)
    nxt = 1
    for i in range(n):
        idx = ranks[i]
        v = values[i]
        prev = roots[i]
        roots[i + 1] = nxt
        lo, hi = 0, size - 1
        while True:
            cur = nxt
            nxt += 1
            cnt[cur] = cnt[prev] + 1
            sm[cur] = sm[prev] + v
            if lo == hi:
                break
            mid = (lo + hi) // 2
            # The copied path is allocated consecutively, so the child we
            # descend into is exactly the next node to be allocated (nxt).
            if idx <= mid:
                lc[cur] = nxt
                rc[cur] = rc[prev]
                prev = lc[prev]
                hi = mid
            else:
                lc[cur] = lc[prev]
                rc[cur] = nxt
                prev = rc[prev]
                lo = mid + 1
    return roots, lc, rc, cnt, sm


def _median_stats(lo_root, hi_root, k, size, lc, rc, cnt, sm):
    """Locate the k-th smallest element of version ``hi_root`` minus ``lo_root``.

    Returns ``(rank, sum_leq, cnt_leq)``: the element's compressed rank, plus
    the sum and count of all elements whose rank is <= that rank. All three
    come from a single root-to-leaf descent.
    """
    a, b = lo_root, hi_root
    lo, hi = 0, size - 1
    sum_leq = cnt_leq = 0
    while lo < hi:
        mid = (lo + hi) // 2
        la, lb = lc[a], lc[b]
        left_cnt = cnt[lb] - cnt[la]
        if k <= left_cnt:
            a, b, hi = la, lb, mid
        else:
            # The whole left half lies below the answer; account for it.
            k -= left_cnt
            cnt_leq += left_cnt
            sum_leq += sm[lb] - sm[la]
            a, b, lo = rc[a], rc[b], mid + 1
    # Every copy at the answer's leaf is <= the answer as well.
    cnt_leq += cnt[b] - cnt[a]
    sum_leq += sm[b] - sm[a]
    return lo, sum_leq, cnt_leq


def main():
    """Read N, Q, A_1..A_N and Q pairs (L, R); print one answer per line.

    Each answer is sum(|A_i - x|) for L <= i <= R (1-based, inclusive), where
    x is the lower median of that range. This is the minimum of the sum of
    absolute deviations. A query with L > R prints 0.
    """
    import sys

    data = sys.stdin.buffer.read().split()
    n, q = int(data[0]), int(data[1])
    values = list(map(int, data[2:2 + n]))

    # Coordinate compression: rank of each distinct value in sorted order.
    unique = sorted(set(values))
    rank_of = {v: i for i, v in enumerate(unique)}
    ranks = [rank_of[v] for v in values]
    roots, lc, rc, cnt, sm = _build_versions(ranks, values, len(unique))

    results = []
    ptr = 2 + n
    for _ in range(q):
        left = int(data[ptr])
        right = int(data[ptr + 1])
        ptr += 2
        if left > right:
            results.append(0)
            continue
        lo_root = roots[left - 1]
        hi_root = roots[right]
        m = right - left + 1
        k = (m + 1) // 2  # 1-based position of the lower median
        x_idx, sum_leq, cnt_leq = _median_stats(
            lo_root, hi_root, k, len(unique), lc, rc, cnt, sm
        )
        x = unique[x_idx]
        total = sm[hi_root] - sm[lo_root]
        # sum(x - a for a <= x) + sum(a - x for a > x), rearranged.
        results.append(total - 2 * sum_leq + x * (2 * cnt_leq - m))

    print('\n'.join(map(str, results)))

if __name__ == "__main__":
    # Solve from stdin only when run as a script, not when imported.
    main()
0