結果

問題 No.924 紲星
ユーザー lam6er
提出日時 2025-03-20 21:18:59
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,264 bytes
コンパイル時間 314 ms
コンパイル使用メモリ 82,288 KB
実行使用メモリ 492,668 KB
最終ジャッジ日時 2025-03-20 21:20:32
合計ジャッジ時間 8,406 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 11 TLE * 5
権限があれば一括ダウンロードができます

ソースコード

diff #
プレゼンテーションモードにする

import bisect
class PersistentSegmentTreeNode:
__slots__ = ['left', 'right', 'count', 'sum_val']
def __init__(self, left=None, right=None, count=0, sum_val=0):
self.left = left
self.right = right
self.count = count
self.sum_val = sum_val
def build(l, r, B):
node = PersistentSegmentTreeNode()
if l == r:
return node
mid = (l + r) // 2
node.left = build(l, mid, B)
node.right = build(mid+1, r, B)
return node
def update(old_node, l, r, target_idx, value, B):
new_node = PersistentSegmentTreeNode()
new_node.count = old_node.count + 1
new_node.sum_val = old_node.sum_val + value
if l == r:
return new_node
mid = (l + r) // 2
if target_idx <= mid:
new_node.left = update(old_node.left, l, mid, target_idx, value, B)
new_node.right = old_node.right
else:
new_node.left = old_node.left
new_node.right = update(old_node.right, mid+1, r, target_idx, value, B)
return new_node
def find_kth(node1, node2, l, r, k, B):
if l == r:
return B[l]
mid = (l + r) // 2
left_count = node1.left.count - node2.left.count
if left_count >= k:
return find_kth(node1.left, node2.left, l, mid, k, B)
else:
return find_kth(node1.right, node2.right, mid+1, r, k - left_count, B)
def query_sum(node1, node2, l, r, target_pos, B):
if r <= target_pos:
return (node1.count - node2.count, node1.sum_val - node2.sum_val)
mid = (l + r) // 2
if target_pos <= mid:
return query_sum(node1.left, node2.left, l, mid, target_pos, B)
else:
left_count, left_sum = query_sum(node1.left, node2.left, l, mid, target_pos, B)
right_count, right_sum = query_sum(node1.right, node2.right, mid+1, r, target_pos, B)
return (left_count + right_count, left_sum + right_sum)
def main():
import sys
input = sys.stdin.read
data = input().split()
idx = 0
N = int(data[idx]); idx +=1
Q = int(data[idx]); idx +=1
A = list(map(int, data[idx:idx+N]))
idx += N
B = sorted(list(set(A)))
if not B:
B.append(0)
B.sort()
pre_sum = [0]*(N+1)
for i in range(1, N+1):
pre_sum[i] = pre_sum[i-1] + A[i-1]
len_B = len(B)
versions = [None]*(N+1)
versions[0] = build(0, len_B-1, B)
for i in range(1, N+1):
x = A[i-1]
pos = bisect.bisect_left(B, x)
versions[i] = update(versions[i-1], 0, len_B-1, pos, x, B)
for _ in range(Q):
L = int(data[idx]); idx +=1
R = int(data[idx]); idx +=1
m = R - L +1
k = (m +1) // 2
root_R = versions[R]
root_L_minus_1 = versions[L-1]
if len_B == 0:
x = 0
else:
x = find_kth(root_R, root_L_minus_1, 0, len_B-1, k, B)
pos = bisect.bisect_right(B, x) -1
if len_B ==0:
count, sum_left = 0, 0
else:
count, sum_left = query_sum(root_R, root_L_minus_1, 0, len_B-1, pos, B)
total_sum = pre_sum[R] - pre_sum[L-1]
sum_right = total_sum - sum_left
right_count = m - count
ans = x * count - sum_left + (sum_right - x * right_count)
print(ans)
if __name__ == "__main__":
main()
הההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההה
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
0