結果
| 問題 |
No.924 紲星
|
| コンテスト | |
| ユーザー |
lam6er
|
| 提出日時 | 2025-03-31 17:39:29 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
MLE
|
| 実行時間 | - |
| コード長 | 3,145 bytes |
| コンパイル時間 | 290 ms |
| コンパイル使用メモリ | 82,084 KB |
| 実行使用メモリ | 512,756 KB |
| 最終ジャッジ日時 | 2025-03-31 17:40:10 |
| 合計ジャッジ時間 | 7,570 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 5 MLE * 1 -- * 10 |
ソースコード
import bisect
class Node:
__slots__ = ['left', 'right', 'cnt', 'sum']
def __init__(self, left=None, right=None, cnt=0, sum=0):
self.left = left
self.right = right
self.cnt = cnt
self.sum = sum
def build(l, r):
if l == r:
return Node(None, None, 0, 0)
mid = (l + r) // 2
left = build(l, mid)
right = build(mid+1, r)
return Node(left, right, left.cnt + right.cnt, left.sum + right.sum)
def update(node, l, r, idx, value):
if l == r:
return Node(None, None, node.cnt + 1, node.sum + value)
mid = (l + r) // 2
if idx <= mid:
new_left = update(node.left, l, mid, idx, value)
new_right = node.right
else:
new_left = node.left
new_right = update(node.right, mid+1, r, idx, value)
return Node(new_left, new_right, new_left.cnt + new_right.cnt, new_left.sum + new_right.sum)
def find_kth(node_l, node_r, l, r, k):
if l == r:
return l
mid = (l + r) // 2
cnt_left = node_r.left.cnt - node_l.left.cnt
if cnt_left >= k:
return find_kth(node_l.left, node_r.left, l, mid, k)
else:
return find_kth(node_l.right, node_r.right, mid+1, r, k - cnt_left)
def get_sum_and_cnt(node_l, node_r, x, l, r):
if r <= x:
return (node_r.sum - node_l.sum, node_r.cnt - node_l.cnt)
mid = (l + r) // 2
sum_total = 0
cnt_total = 0
s_left, c_left = get_sum_and_cnt(node_l.left, node_r.left, x, l, mid)
sum_total += s_left
cnt_total += c_left
if x > mid:
s_right, c_right = get_sum_and_cnt(node_l.right, node_r.right, x, mid+1, r)
sum_total += s_right
cnt_total += c_right
return (sum_total, cnt_total)
def main():
import sys
input = sys.stdin.read
data = input().split()
ptr = 0
N, Q = int(data[ptr]), int(data[ptr+1])
ptr +=2
A = list(map(int, data[ptr:ptr+N]))
ptr +=N
queries = []
for _ in range(Q):
L = int(data[ptr])
R = int(data[ptr+1])
queries.append( (L, R) )
ptr +=2
# Coordinate compression
unique = sorted(set(A))
map_val = {v:i for i, v in enumerate(unique)}
M = len(unique)
compressed = [map_val[v] for v in A]
# Build persistent segment tree
roots = []
roots.append( build(0, M-1) )
for i in range(N):
val = A[i]
idx = map_val[val]
new_root = update(roots[-1], 0, M-1, idx, val)
roots.append( new_root )
# Process queries
results = []
for L, R in queries:
if L > R:
results.append(0)
continue
node_r = roots[R]
node_l = roots[L-1]
m = R - L + 1
k = (m - 1) // 2 +1 # 1-based for k-th
x_idx = find_kth(node_l, node_r, 0, M-1, k)
x = unique[x_idx]
sum_leq, cnt_leq = get_sum_and_cnt(node_l, node_r, x_idx, 0, M-1)
sum_total = node_r.sum - node_l.sum
# Calculate sum of absolute differences
res = (sum_total - 2 * sum_leq) + x * (2 * cnt_leq - m)
results.append(res)
print('\n'.join(map(str, results)))
if __name__ == '__main__':
main()
lam6er