結果
問題 |
No.924 紲星
|
ユーザー |
![]() |
提出日時 | 2025-06-12 13:50:38 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 2,757 ms / 4,000 ms |
コード長 | 3,688 bytes |
コンパイル時間 | 445 ms |
コンパイル使用メモリ | 82,776 KB |
実行使用メモリ | 391,660 KB |
最終ジャッジ日時 | 2025-06-12 13:52:04 |
合計ジャッジ時間 | 22,964 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 16 |
ソースコード
import sys
from itertools import accumulate


def main():
    """Answer range "sum of absolute deviations" queries (yukicoder No.924).

    Input (whitespace-separated on stdin): ``N Q``, then ``A_1 .. A_N``, then
    ``Q`` pairs ``L R`` (1-indexed, inclusive). For each query the answer is
    ``sum_{i=L..R} |A_i - m|`` where ``m`` is the lower median of the range,
    which is the minimum of ``sum |A_i - x|`` over all ``x``. With ``k`` the
    rank of the lower median and ``low`` the sum of the ``k`` smallest values
    in the range, that equals
    ``(m * k - low) + (total - low - m * (length - k))``.

    A wavelet tree over the sorted distinct values finds ``m`` and ``low`` in
    a single root-to-leaf descent per query.
    Answers are written one per line to stdout.
    """
    data = sys.stdin.read().split()
    n, q = int(data[0]), int(data[1])
    a = list(map(int, data[2:2 + n]))
    pos = 2 + n

    # prefix[i] = a[0] + ... + a[i-1]; gives every range total in O(1).
    prefix = list(accumulate(a, initial=0))

    # Distinct values, ascending. Tree nodes cover index ranges [lo, hi] of it.
    values = sorted(set(a))

    # Wavelet tree as parallel per-node lists; node 0 is the root.
    # For an internal node with pivot values[(lo + hi) // 2]:
    #   cnt_pref[v][i] = how many of the node's first i elements are <= pivot
    #                    (i.e. routed to the left child),
    #   sum_pref[v][i] = the sum of those elements.
    # Leaves (lo >= hi) hold a single distinct value and need no arrays.
    node_lo = [0]
    node_hi = [len(values) - 1]
    cnt_pref = [None]
    sum_pref = [None]
    left_child = [-1]
    right_child = [-1]

    # Iterative DFS build: each entry is (node id, node's elements in the
    # original array order). Stable partitioning keeps that order per child.
    stack = [(0, a)]
    while stack:
        node, seq = stack.pop()
        lo, hi = node_lo[node], node_hi[node]
        if lo >= hi:
            continue
        mid = (lo + hi) // 2
        pivot = values[mid]
        left_seq = []
        right_seq = []
        counts = [0]
        sums = [0]
        c = s = 0
        for v in seq:
            if v <= pivot:
                left_seq.append(v)
                c += 1
                s += v
            else:
                right_seq.append(v)
            counts.append(c)
            sums.append(s)
        cnt_pref[node] = counts
        sum_pref[node] = sums

        left_id = len(node_lo)
        right_id = left_id + 1
        node_lo.extend((lo, mid + 1))
        node_hi.extend((mid, hi))
        cnt_pref.extend((None, None))
        sum_pref.extend((None, None))
        left_child.extend((-1, -1))
        right_child.extend((-1, -1))
        left_child[node] = left_id
        right_child[node] = right_id
        stack.append((left_id, left_seq))
        stack.append((right_id, right_seq))

    out = []
    for _ in range(q):
        # Convert 1-indexed inclusive [L, R] to 0-indexed half-open [start, stop).
        start = int(data[pos]) - 1
        stop = int(data[pos + 1])
        pos += 2
        length = stop - start
        rank = (length + 1) // 2  # rank of the lower median

        # One descent finds the rank-th smallest value and accumulates the
        # sum of every element strictly routed left of the path.
        node = 0
        need = rank
        low = 0
        i, j = start, stop
        while node_lo[node] < node_hi[node]:
            counts = cnt_pref[node]
            ci, cj = counts[i], counts[j]
            in_left = cj - ci
            if need <= in_left:
                i, j = ci, cj
                node = left_child[node]
            else:
                sums = sum_pref[node]
                low += sums[j] - sums[i]
                need -= in_left
                # Elements routed right before position p: p - counts[p].
                i, j = i - ci, j - cj
                node = right_child[node]
        median = values[node_lo[node]]
        # The remaining `need` smallest elements all equal the median.
        low += median * need

        total = prefix[stop] - prefix[start]
        answer = (median * rank - low) + (total - low - median * (length - rank))
        out.append(str(answer))

    if out:
        sys.stdout.write("\n".join(out) + "\n")


if __name__ == "__main__":
    main()