結果

問題 No.924 紲星
ユーザー lam6er
提出日時 2025-03-20 21:18:59
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,264 bytes
コンパイル時間 314 ms
コンパイル使用メモリ 82,288 KB
実行使用メモリ 492,668 KB
最終ジャッジ日時 2025-03-20 21:20:32
合計ジャッジ時間 8,406 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 11 TLE * 5
権限があれば一括ダウンロードができます

ソースコード

diff #

import bisect

class PersistentSegmentTreeNode:
    """A single node of the persistent segment tree.

    The node stores links to its two children together with two
    aggregates for the index range it covers.  ``__slots__`` is declared
    so that instances carry no per-object ``__dict__``.
    """

    # left / right : child nodes (None by default)
    # count        : aggregate counter (incremented by 1 per insertion in update())
    # sum_val      : aggregate sum (incremented by the inserted value in update())
    __slots__ = ('left', 'right', 'count', 'sum_val')

    def __init__(self, left=None, right=None, count=0, sum_val=0):
        self.left, self.right = left, right
        self.count, self.sum_val = count, sum_val

def build(l, r, B):
    """Create the initial, all-zero tree covering indices ``l..r`` inclusive.

    Every node is created with ``count == 0`` and ``sum_val == 0``; leaf
    nodes (``l == r``) keep both children as ``None``.  ``B`` is only
    forwarded to the recursive calls and is otherwise not used here.

    Returns the root node of the constructed subtree.
    """
    if l == r:
        return PersistentSegmentTreeNode()
    middle = (l + r) // 2
    left_child = build(l, middle, B)
    right_child = build(middle + 1, r, B)
    return PersistentSegmentTreeNode(left_child, right_child)

def update(old_node, l, r, target_idx, value, B):
    """Return the root of a new version with ``value`` added at ``target_idx``.

    Only the nodes on the root-to-leaf path for ``target_idx`` are newly
    created; each of them has ``count`` one larger and ``sum_val`` larger
    by ``value`` than its counterpart in the old version.  The sibling
    subtree at every step is taken from the old version unchanged, and
    ``old_node`` itself is never modified.  ``B`` is not used.
    """
    root = PersistentSegmentTreeNode(
        count=old_node.count + 1,
        sum_val=old_node.sum_val + value,
    )
    fresh, stale = root, old_node
    lo, hi = l, r
    while lo != hi:
        middle = (lo + hi) // 2
        if target_idx <= middle:
            # Descend left: the right subtree is shared with the old version.
            fresh.right = stale.right
            stale = stale.left
            child = PersistentSegmentTreeNode(
                count=stale.count + 1,
                sum_val=stale.sum_val + value,
            )
            fresh.left = child
            hi = middle
        else:
            # Descend right: the left subtree is shared with the old version.
            fresh.left = stale.left
            stale = stale.right
            child = PersistentSegmentTreeNode(
                count=stale.count + 1,
                sum_val=stale.sum_val + value,
            )
            fresh.right = child
            lo = middle + 1
        fresh = child
    return root

def find_kth(node1, node2, l, r, k, B):
    """Return ``B[i]`` for the leaf ``i`` holding the k-th entry of ``node1 - node2``.

    ``node1`` and ``node2`` are roots of two versions covering ``l..r``;
    the per-node difference of their ``count`` fields is used to decide
    whether the k-th (1-based) entry lies in the left or right half.
    """
    newer, older = node1, node2
    lo, hi, rank = l, r, k
    while lo != hi:
        middle = (lo + hi) // 2
        in_left = newer.left.count - older.left.count
        if rank <= in_left:
            newer, older, hi = newer.left, older.left, middle
        else:
            # Skip everything stored in the left half.
            rank -= in_left
            newer, older, lo = newer.right, older.right, middle + 1
    return B[lo]

def query_sum(node1, node2, l, r, target_pos, B):
    """Return ``(count, sum)`` of ``node1 - node2`` over indices ``l..target_pos``.

    ``node1`` and ``node2`` are roots of two versions covering ``l..r``.
    The result is a tuple of the differences of ``count`` and ``sum_val``
    restricted to indices not greater than ``target_pos``.  ``B`` is not
    used.
    """
    newer, older = node1, node2
    lo, hi = l, r
    cnt_acc = 0
    sum_acc = 0
    # Walk down until the current range lies entirely at or below target_pos.
    while hi > target_pos:
        middle = (lo + hi) // 2
        if target_pos > middle:
            # The whole left half qualifies: take its aggregates and go right.
            cnt_acc += newer.left.count - older.left.count
            sum_acc += newer.left.sum_val - older.left.sum_val
            newer, older, lo = newer.right, older.right, middle + 1
        else:
            newer, older, hi = newer.left, older.left, middle
    return (
        cnt_acc + (newer.count - older.count),
        sum_acc + (newer.sum_val - older.sum_val),
    )

def main():
    """Read the problem input from stdin and answer every range query.

    Input (whitespace separated): ``N Q``, then ``N`` integers ``A``,
    then ``Q`` pairs ``L R`` (1-based, inclusive).

    For each query the k-th smallest value ``x`` of ``A[L..R]`` with
    ``k = (R - L + 2) // 2`` is located, and ``sum(|a - x|)`` over the
    range is written on its own line.  All answers are emitted with a
    single write at the end instead of one ``print`` per query.
    """
    import sys

    data = sys.stdin.read().split()
    N = int(data[0])
    Q = int(data[1])
    A = [int(tok) for tok in data[2:2 + N]]
    idx = 2 + N

    # Coordinate compression: sorted distinct values.  A dummy 0 keeps the
    # tree non-empty when A is empty, so no emptiness checks are needed later.
    B = sorted(set(A)) or [0]
    top = len(B) - 1

    # pre_sum[i] = A[0] + ... + A[i-1]
    pre_sum = [0] * (N + 1)
    for i, a in enumerate(A, 1):
        pre_sum[i] = pre_sum[i - 1] + a

    # versions[i] is the tree containing the first i elements of A.
    versions = [None] * (N + 1)
    versions[0] = build(0, top, B)
    for i, a in enumerate(A, 1):
        versions[i] = update(
            versions[i - 1], 0, top, bisect.bisect_left(B, a), a, B
        )

    answers = []
    for _ in range(Q):
        L = int(data[idx])
        R = int(data[idx + 1])
        idx += 2

        m = R - L + 1
        k = (m + 1) // 2
        root_R = versions[R]
        root_L_minus_1 = versions[L - 1]

        x = find_kth(root_R, root_L_minus_1, 0, top, k, B)
        # Index of x in B; x always comes from B, so pos >= 0.
        pos = bisect.bisect_right(B, x) - 1
        count, sum_left = query_sum(root_R, root_L_minus_1, 0, top, pos, B)

        total_sum = pre_sum[R] - pre_sum[L - 1]
        sum_right = total_sum - sum_left
        right_count = m - count
        # Elements <= x contribute (x - a); elements > x contribute (a - x).
        answers.append(x * count - sum_left + (sum_right - x * right_count))

    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")

# Run the solver only when this file is executed as a script.
if __name__ == "__main__":
    main()
0