結果
| 問題 | No.924 紲星 |
|---|---|
| コンテスト | |
| ユーザー | gew1fw |
| 提出日時 | 2025-06-12 18:51:37 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 3,467 ms / 4,000 ms |
| コード長 | 3,688 bytes |
| コンパイル時間 | 213 ms |
| コンパイル使用メモリ | 82,092 KB |
| 実行使用メモリ | 391,524 KB |
| 最終ジャッジ日時 | 2025-06-12 18:52:14 |
| 合計ジャッジ時間 | 28,188 ms |
| ジャッジサーバーID (参考情報) | judge2 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 16 |
ソースコード
import sys
def main():
    """Solve yukicoder No.924 for the instance given on stdin.

    Input (whitespace-separated tokens): ``N Q``, then ``A_1 .. A_N``, then
    ``Q`` pairs ``L R`` (1-indexed, inclusive). For each query one line is
    printed: ``min_x sum_{i=L..R} |A_i - x|``, which is attained at the
    (lower) median of ``A[L..R]``. Nothing is printed when Q is 0.

    A wavelet tree over the sorted distinct values of ``A`` yields, per query,
    both the k-th smallest element of the range and the sum of the k smallest
    elements in a single root-to-leaf walk. The tree depth is about log2 of
    the number of distinct values.
    """
    tokens = sys.stdin.read().split()
    n, q = int(tokens[0]), int(tokens[1])
    if q == 0:
        # No queries: nothing to print, so skip building the tree entirely.
        return
    A = list(map(int, tokens[2:2 + n]))

    # prefix[i] == sum(A[:i]), for O(1) range totals.
    prefix = [0] * (n + 1)
    for i, v in enumerate(A):
        prefix[i + 1] = prefix[i] + v

    # Distinct values in ascending order; tree nodes cover index ranges of it.
    values = sorted(set(A))

    class _Node:
        """Wavelet-tree node over a contiguous index range of ``values``."""

        __slots__ = ("leaf_value", "left", "right", "left_count", "left_sum")

        def __init__(self):
            # Set only on leaves, whose range holds exactly one distinct value.
            self.leaf_value = None
            self.left = None
            self.right = None
            # Internal nodes only: left_count[i] / left_sum[i] are the count /
            # sum of the first i elements of this node's sequence that are
            # routed to the left child (value <= pivot).
            self.left_count = [0]
            self.left_sum = [0]

    def build(seq, lo, hi):
        """Build the subtree for ``seq``, whose values all lie in values[lo..hi]."""
        node = _Node()
        if lo == hi:
            # Queries stop at a leaf without reading prefix arrays, so none are built.
            node.leaf_value = values[lo]
            return node
        mid = (lo + hi) // 2
        pivot = values[mid]
        left_seq = []
        right_seq = []
        counts = node.left_count
        sums = node.left_sum
        c = s = 0
        for v in seq:
            if v <= pivot:
                left_seq.append(v)
                c += 1
                s += v
            else:
                right_seq.append(v)
            counts.append(c)
            sums.append(s)
        node.left = build(left_seq, lo, mid)
        node.right = build(right_seq, mid + 1, hi)
        return node

    root = build(A, 0, len(values) - 1)

    def smallest_k(begin, end, k):
        """Return (k-th smallest, sum of the k smallest) of A[begin:end].

        ``begin:end`` is a half-open 0-indexed range and 1 <= k <= end - begin.
        """
        node = root
        acc = 0
        while node.leaf_value is None:
            counts = node.left_count
            cl, cr = counts[begin], counts[end]
            n_left = cr - cl
            if k <= n_left:
                # The k-th smallest is routed left; map to left-child positions.
                begin, end = cl, cr
                node = node.left
            else:
                # Every left-routed element is among the k smallest.
                sums = node.left_sum
                acc += sums[end] - sums[begin]
                k -= n_left
                # Position i here is position i - counts[i] in the right child.
                begin, end = begin - cl, end - cr
                node = node.right
        # The remaining k elements all equal this leaf's single value.
        value = node.leaf_value
        return value, acc + value * k

    out = []
    pos = 2 + n
    for _ in range(q):
        # 1-indexed inclusive [L, R] -> 0-indexed half-open [L-1, R).
        begin = int(tokens[pos]) - 1
        end = int(tokens[pos + 1])
        pos += 2
        length = end - begin
        k = (length + 1) // 2
        median, low_sum = smallest_k(begin, end, k)
        total = prefix[end] - prefix[begin]
        # The k smallest are <= median and the other length-k are >= median,
        # so sum |A_i - median| splits into two signed partial sums.
        out.append((median * k - low_sum)
                   + (total - low_sum - median * (length - k)))
    sys.stdout.write("\n".join(map(str, out)) + "\n")
if __name__ == "__main__":
    # Run the solver only when executed as a script, not when imported.
    main()
gew1fw