結果

| 項目 | 内容 |
|---|---|
| 問題 | No.2065 Sum of Min |
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-04-15 23:47:04 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 1,303 ms / 2,000 ms |
| コード長 | 2,389 bytes |
| コンパイル時間 | 189 ms |
| コンパイル使用メモリ | 82,724 KB |
| 実行使用メモリ | 190,424 KB |
| 最終ジャッジ日時 | 2025-04-15 23:49:35 |
| 合計ジャッジ時間 | 20,031 ms |
| ジャッジサーバーID (参考情報) | judge1 / judge3 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 20 |
ソースコード
import sys
import bisect
import itertools
class Node:
    """A segment-tree node covering the inclusive index range [start, end].

    Attributes:
        start, end: inclusive bounds of the covered index range.
        left, right: child nodes, or None until children are attached.
        sorted_list: values of the covered range (filled in by the builder).
        prefix_sum: running sums over ``sorted_list`` (filled in by the builder).
    """

    __slots__ = ('start', 'end', 'left', 'right', 'sorted_list', 'prefix_sum')

    def __init__(self, start, end):
        self.start, self.end = start, end
        # No children yet; the builder attaches them for internal nodes.
        self.left = self.right = None
        # Two distinct empty lists -- they must never alias each other.
        self.sorted_list = []
        self.prefix_sum = []
def merge(left, right):
    """Return a new ascending list holding every element of two sorted lists.

    Args:
        left: a list already sorted in ascending order.
        right: a list already sorted in ascending order.

    Returns:
        A new list; neither input is modified.
        On ties, elements of ``left`` come before elements of ``right``.
    """
    # Timsort detects the two pre-sorted runs in ``left + right`` and merges
    # them in roughly linear time inside C. This replaces the hand-written
    # two-pointer loop. sorted() is stable, so on equal keys the elements
    # from ``left`` still precede those from ``right``, as before.
    return sorted(left + right)
def build_segment_tree(arr, l, r):
    """Build a merge-sort tree over ``arr[l..r]`` (inclusive) and return its root.

    Each node stores the values of its range in ascending order. It also stores
    ``prefix_sum``, where ``prefix_sum[k]`` is the sum of the first ``k``
    sorted values (``prefix_sum[0] == 0``). Children split the range at
    ``mid = (l + r) // 2``.

    Args:
        arr: the sequence of values to index.
        l, r: inclusive index bounds into ``arr``.

    Returns:
        The root ``Node``. If ``l > r`` (e.g. ``arr`` is empty and the
        caller passes ``0, -1``), an empty childless node with
        ``prefix_sum == [0]`` is returned.
    """
    node = Node(l, r)
    if l > r:
        # Empty range. Without this guard, mid = (l + r) // 2 keeps
        # producing another empty range (e.g. (0, -1) -> (0, -1)), so the
        # recursion never terminates and raises RecursionError.
        node.prefix_sum = [0]
        return node
    if l == r:
        value = arr[l]
        node.sorted_list = [value]
        node.prefix_sum = [0, value]
        return node
    mid = (l + r) // 2
    node.left = build_segment_tree(arr, l, mid)
    node.right = build_segment_tree(arr, mid + 1, r)
    node.sorted_list = merge(node.left.sorted_list, node.right.sorted_list)
    # Leading 0 so that prefix_sum[k] is the sum of the k smallest values.
    node.prefix_sum = list(itertools.accumulate(node.sorted_list, initial=0))
    return node
def query_sum_less_equal(root, L, R, X):
    """Return ``(sum, count)`` of the values <= X at indices L..R (inclusive).

    The query range is decomposed into nodes that lie fully inside [L, R].
    In each such node, a binary search on its sorted values gives how many
    are <= X. The node's prefix sum at that position gives their total.
    Nodes disjoint from [L, R] are skipped. Partially overlapping nodes are
    expanded into their children.
    """
    acc_sum = acc_cnt = 0
    pending = [root]
    while pending:
        cur = pending.pop()
        if R < cur.start or L > cur.end:
            continue  # no overlap with the query range
        if cur.start < L or cur.end > R:
            # Partial overlap: explore both children. The left child is
            # pushed last, so it is visited first.
            pending.extend((cur.right, cur.left))
            continue
        k = bisect.bisect_right(cur.sorted_list, X)
        acc_sum += cur.prefix_sum[k]
        acc_cnt += k
    return acc_sum, acc_cnt
def main():
    """Read the problem input from stdin and print one answer per query.

    Input tokens: N, Q, then A_1..A_N, then Q triples ``L R X`` with
    1-based inclusive indices. For each query, the printed value is the sum
    of the values <= X in A[L..R] plus X times the number of values > X.
    This equals sum(min(A_i, X)) over the range.
    """
    # Read everything once; int() accepts bytes tokens directly. Avoid
    # shadowing the builtin ``input``.
    tokens = sys.stdin.buffer.read().split()
    n, q = int(tokens[0]), int(tokens[1])
    a = list(map(int, tokens[2:2 + n]))
    root = build_segment_tree(a, 0, n - 1)

    answers = []
    pos = 2 + n
    for _ in range(q):
        lo = int(tokens[pos]) - 1       # convert to 0-based
        hi = int(tokens[pos + 1]) - 1
        x = int(tokens[pos + 2])
        pos += 3
        sum_le, count_le = query_sum_less_equal(root, lo, hi, x)
        # Every value > x contributes exactly x to the sum of minimums.
        answers.append(sum_le + x * ((hi - lo + 1) - count_le))

    # A single write instead of one print() per query; this dominates
    # the runtime when Q is large.
    if answers:
        sys.stdout.write('\n'.join(map(str, answers)) + '\n')


if __name__ == '__main__':
    main()
lam6er