結果
| 問題 |
No.924 紲星
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2024-04-14 23:09:12 |
| 言語 | Go (1.23.4) |
| 結果 |
AC
|
| 実行時間 | 626 ms / 4,000 ms |
| コード長 | 17,222 bytes |
| コンパイル時間 | 12,395 ms |
| コンパイル使用メモリ | 230,068 KB |
| 実行使用メモリ | 73,840 KB |
| 最終ジャッジ日時 | 2024-10-04 04:05:48 |
| 合計ジャッジ時間 | 17,159 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 16 |
ソースコード
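// Wavelet Matrix that also maintains range contributions (sums).
// !Note: xor has no effect when querying range contributions.
// CountRange(start, end, a, b, xor) - the count and the sum of the values in [start, end) that lie in [a, b).
// CountPrefix(start, end, x, xor) - the count and the sum of the values in [start, end) that lie in [0, x).
// Kth(start, end, k, xor) - the k-th smallest value (0-indexed) in [start, end) and the sum of the k smallest values (excluding that value itself).
// Floor(start, end, x, xor) - the largest value in [start, end) that is less than or equal to x.
// Ceiling(start, end, x, xor) - the smallest value in [start, end) that is greater than or equal to x.
// MaxRightValue(start, end, xor, check) - the largest value for which check(prefixSum) is true, where prefixSum is the sum of the values in [0, val).
// MaxRightCount(start, end, xor, check) - the largest prefix count of the range for which check(prefixSum) is true.
// NOTE(review): the names above do not all match the methods defined below
// (Count, Kth, KthValueAndSum, Median, Sum, SumAll, MaxRight and the *Segments variants).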
package main
import (
"bufio"
"fmt"
"math/bits"
"os"
"sort"
)
// main runs the yukicoder No.924 solver and afterwards exercises a small
// BitVector (one bit set, every rank queried) without printing anything.
func main() {
	// demo()
	区间最短距离和()
	// abc281_e()

	const size = 1 << 13
	bv := NewBitVector(size)
	bv.Set(3125)
	bv.Build()
	for i := int32(0); i < size; i++ {
		bv.Rank(i, false)
	}
}
// demo prints the results of a few sample queries on two small arrays.
func demo() {
	first := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
	wm := NewWaveletMatrixWithSum(first, first, -1, false)
	fmt.Println(wm.Count(0, 3, 3, 7, 0))
	value, sum := wm.KthValueAndSum(1, 5, 0, 0)
	fmt.Println(value, sum)

	second := []int{1, 2, 3, 4, 5}
	wm = NewWaveletMatrixWithSum(second, second, -1, false)
	value, sum = wm.KthValueAndSum(0, 5, 2, 0)
	fmt.Println(value, sum)
	value, sum = wm.KthValueAndSum(0, 6, 5, 0)
	fmt.Println(value, sum)
	fmt.Println(wm.SumAll(0, 4))
}
// abc281_e reads n, m, k and n numbers from stdin and, for every window of
// m consecutive numbers, prints the sum of its k smallest elements.
func abc281_e() {
	reader := bufio.NewReader(os.Stdin)
	writer := bufio.NewWriter(os.Stdout)
	defer writer.Flush()

	var n, m, k int32
	fmt.Fscan(reader, &n, &m, &k)

	nums := make([]int, n)
	for i := int32(0); i < n; i++ {
		fmt.Fscan(reader, &nums[i])
	}

	wm := NewWaveletMatrixWithSum(nums, nums, -1, false)
	for left := int32(0); left < n-m+1; left++ {
		_, smallest := wm.KthValueAndSum(left, left+m, k, 0)
		fmt.Fprintln(writer, smallest)
	}
}
// 区间最短距离和 solves https://yukicoder.me/problems/no/924.
//
// Constraints noted by the author: n, q <= 2e5 and -1e9 <= nums[i] <= 1e9.
// For every query [l, r] (1-indexed, inclusive) it prints the sum of the
// distances from the range median to every element of the range, i.e. the
// minimum of f(x) = sum |nums[i] - x| over l <= i <= r.
func 区间最短距离和() {
	in := bufio.NewReader(os.Stdin)
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()

	var n, q int
	fmt.Fscan(in, &n, &q)

	// Shift every value up so that all of them are non-negative before they
	// are indexed bit by bit. The answer is built from differences between
	// values, so the shift cancels out.
	const offset = int(1e9 + 10)
	nums := make([]int, n)
	for i := range nums {
		fmt.Fscan(in, &nums[i])
		nums[i] += offset
	}

	wm := NewWaveletMatrixWithSum(nums, nums, -1, false)
	for ; q > 0; q-- {
		var start, end int32
		fmt.Fscan(in, &start, &end)
		start-- // convert to the half-open, 0-indexed range [start, end)

		size := end - start
		lowerCount := size / 2
		upperCount := size - lowerCount

		// mid is the median; lowerSum is the sum of the lowerCount smallest values.
		mid, lowerSum := wm.KthValueAndSum(start, end, lowerCount, 0)
		upperSum := wm.SumAll(start, end) - lowerSum

		res := mid*int(lowerCount) - lowerSum
		res += upperSum - mid*int(upperCount)
		fmt.Fprintln(out, res)
	}
}
// WmValue is the type of the values indexed by the wavelet matrix.
type WmValue = int
// WmSum is the type of the per-element data that the sum queries aggregate.
type WmSum = int
// e returns the identity element of the aggregation (zero, for addition).
func (*WaveletMatrixWithSum) e() WmSum {
	var identity WmSum
	return identity
}

// op combines two aggregates; here it is plain addition.
func (*WaveletMatrixWithSum) op(a, b WmSum) WmSum {
	combined := a
	combined += b
	return combined
}

// inv returns the inverse of a under op, i.e. its negation.
func (*WaveletMatrixWithSum) inv(a WmSum) WmSum {
	return 0 - a
}
// WaveletMatrixWithSum is a wavelet matrix that can additionally keep, for
// every bit level, prefix aggregates of per-element data.
type WaveletMatrixWithSum struct {
// n is the number of indexed elements; log is the number of bit levels.
n, log int32
// mid[d] is the number of elements whose bit d is 0 when level d is built.
mid []int32
// bv[d] marks the positions whose value has bit d set when level d is built.
bv []*BitVector
// key holds the sorted distinct input values; it is only filled when compress is true.
key []WmValue
// setLog is true when log was passed explicitly (anything other than -1).
setLog bool
// presum[log] holds prefix aggregates of the sum data in the original order;
// presum[d] holds them in the order obtained after partitioning by bit d.
// It stays nil when no sum data is provided.
presum [][]WmSum
// compress is true when the values were replaced by their rank in key.
compress bool
}
// NewWaveletMatrixWithSum builds a wavelet matrix over nums.
//
//   - nums: the values to index.
//   - sumData: the per-element data to aggregate; nil means no sum data is needed.
//   - log: must be passed when xor queries are needed; -1 selects the default.
//   - compress: whether nums should be coordinate-compressed.
func NewWaveletMatrixWithSum(nums []WmValue, sumData []WmSum, log int32, compress bool) *WaveletMatrixWithSum {
	wm := new(WaveletMatrixWithSum)
	wm.build(nums, sumData, log, compress)
	return wm
}
// build fills the receiver from nums and sumData: the bit vector and split
// point of every bit level and, when sumData is non-empty, the prefix
// aggregates of every level. Both inputs are copied first, so the caller's
// slices are left untouched.
// It panics when compress is requested together with an explicit log.
func (wm *WaveletMatrixWithSum) build(nums []WmValue, sumData []WmSum, log int32, compress bool) {
// Work on private copies; the arrays are reordered level by level below.
numsCopy := append(nums[:0:0], nums...)
sumDataCopy := append(sumData[:0:0], sumData...)
wm.n = int32(len(numsCopy))
wm.log = log
wm.compress = compress
wm.setLog = log != -1
// Nothing to index: leave every table nil.
if wm.n == 0 {
wm.log = 0
return
}
makeSum := len(sumData) > 0
if compress {
if wm.setLog {
panic("compress and log should not be set at the same time")
}
// Replace every value by its rank among the sorted distinct values and
// remember the distinct values in key so results can be mapped back.
wm.key = make([]WmValue, 0, wm.n)
order := wm.argSort(numsCopy)
for _, i := range order {
if len(wm.key) == 0 || wm.key[len(wm.key)-1] != numsCopy[i] {
wm.key = append(wm.key, numsCopy[i])
}
numsCopy[i] = WmValue(len(wm.key) - 1)
}
}
// Default bit width: enough bits for the largest value, and at least one.
if wm.log == -1 {
tmp := wm.maxs(numsCopy)
if tmp < 1 {
tmp = 1
}
wm.log = int32(bits.Len(uint(tmp)))
}
wm.mid = make([]int32, wm.log)
wm.bv = make([]*BitVector, wm.log)
for i := range wm.bv {
wm.bv[i] = NewBitVector(wm.n)
}
if makeSum {
// One prefix table per level plus one (index log) for the original order.
wm.presum = make([][]WmSum, 1+wm.log)
for i := range wm.presum {
sums := make([]WmSum, wm.n+1)
for j := range sums {
sums[j] = wm.e()
}
wm.presum[i] = sums
}
}
// Without sum data the S buffers are still indexed below, so give them a valid length.
if len(sumDataCopy) == 0 {
sumDataCopy = make([]WmSum, len(numsCopy))
}
// A/S hold the current ordering; A0/S0 and A1/S1 collect the elements whose
// current bit is 0 and 1 respectively.
A, S := numsCopy, sumDataCopy
A0, A1 := make([]WmValue, wm.n), make([]WmValue, wm.n)
S0, S1 := make([]WmSum, wm.n), make([]WmSum, wm.n)
// d runs one step past the lowest bit so that the final ordering also gets
// its prefix table (presum[0]).
for d := wm.log - 1; d >= -1; d-- {
p0, p1 := int32(0), int32(0)
if makeSum {
// Prefix aggregates of the ordering that exists before splitting on bit d.
tmp := wm.presum[d+1]
for i := int32(0); i < wm.n; i++ {
tmp[i+1] = wm.op(tmp[i], S[i])
}
}
if d == -1 {
break
}
// Stable partition by bit d: zeros keep their relative order in A0/S0,
// ones in A1/S1, and the ones are marked in the level's bit vector.
for i := int32(0); i < wm.n; i++ {
f := (A[i] >> d & 1) == 1
if !f {
if makeSum {
S0[p0] = S[i]
}
A0[p0] = A[i]
p0++
} else {
if makeSum {
S1[p1] = S[i]
}
wm.bv[d].Set(i)
A1[p1] = A[i]
p1++
}
}
wm.mid[d] = p0
wm.bv[d].Build()
// The zero buffer becomes the new ordering; the old ordering is reused as
// the next zero buffer. The ones are appended right after the zeros.
A, A0 = A0, A
S, S0 = S0, S
for i := int32(0); i < p1; i++ {
A[p0+i] = A1[i]
S[p0+i] = S1[i]
}
}
}
// Count returns the number of elements in [start, end) whose value (after
// xor with xorVal) lies in [a, b). It is the difference of two prefix counts.
func (wm *WaveletMatrixWithSum) Count(start, end int32, a, b WmValue, xorVal WmValue) int32 {
	belowB := wm._prefixCount(start, end, b, xorVal)
	belowA := wm._prefixCount(start, end, a, xorVal)
	return belowB - belowA
}
// Kth returns the k-th smallest (k >= 0, 0-indexed) value among the elements
// of [start, end), where values are compared after xor with xorVal.
//
// It panics if xorVal is non-zero while log was not set explicitly.
// k is expected to satisfy 0 <= k < end-start.
func (wm *WaveletMatrixWithSum) Kth(start, end, k int32, xorVal WmValue) WmValue {
	if xorVal != 0 && !wm.setLog {
		panic("log should be set")
	}
	count := int32(0)
	res := WmValue(0)
	for d := wm.log - 1; d >= 0; d-- {
		f := (xorVal>>d)&1 == 1
		// l0/r0: number of stored 0 bits before start/end at this level.
		l0, r0 := wm.bv[d].Rank(start, false), wm.bv[d].Rank(end, false)
		// c: number of elements whose xor-ed bit d is 0, i.e. the smaller half.
		c := r0 - l0
		if f {
			c = (end - start) - (r0 - l0)
		}
		if count+c > k {
			// The answer has xor-ed bit 0: stored bit is 1 when flipped, else 0.
			if f {
				start += wm.mid[d] - l0
				end += wm.mid[d] - r0
			} else {
				start, end = l0, r0
			}
		} else {
			// The answer has xor-ed bit 1: skip the c smaller elements and
			// follow the stored bit 0 when flipped, else the stored bit 1.
			count += c
			res |= 1 << d
			if f {
				start, end = l0, r0
			} else {
				start += wm.mid[d] - l0
				end += wm.mid[d] - r0
			}
		}
	}
	if wm.compress {
		res = wm.key[res]
	}
	return res
}
// KthValueAndSum returns, for the range [start, end), the k-th smallest value
// (0-indexed) together with the aggregate of the sum data of the k smallest
// elements (the k-th one itself excluded).
//
// If the range is empty it returns (-1, identity); if k >= end-start it
// returns (-1, aggregate of the whole range). It panics if xorVal is non-zero
// while log was not set, or if no sum data was provided.
func (wm *WaveletMatrixWithSum) KthValueAndSum(start, end, k int32, xorVal WmValue) (WmValue, WmSum) {
	if start >= end {
		return -1, wm.e()
	}
	if k >= end-start {
		return -1, wm.SumAll(start, end)
	}
	if xorVal != 0 && !wm.setLog {
		panic("log should be set when xor is used")
	}
	if len(wm.presum) == 0 {
		panic("sumData is not provided")
	}

	var (
		skipped int32
		value   WmValue
	)
	total := wm.e()
	for d := wm.log - 1; d >= 0; d-- {
		flipped := (xorVal>>d)&1 == 1
		zeroL, zeroR := wm.bv[d].Rank(start, false), wm.bv[d].Rank(end, false)
		oneL, oneR := start+wm.mid[d]-zeroL, end+wm.mid[d]-zeroR

		// "low" is the half whose xor-ed bit is 0, "high" the other one.
		lowL, lowR, highL, highR := zeroL, zeroR, oneL, oneR
		if flipped {
			lowL, lowR, highL, highR = oneL, oneR, zeroL, zeroR
		}

		lowCount := lowR - lowL
		if skipped+lowCount > k {
			start, end = lowL, lowR
			continue
		}
		// The whole low half is smaller than the answer: account for it.
		total = wm.op(total, wm._get(d, lowL, lowR))
		skipped += lowCount
		value |= 1 << d
		start, end = highL, highR
	}

	// Everything left in [start, end) equals the answer; take the missing copies.
	total = wm.op(total, wm._get(0, start, start+k-skipped))
	if wm.compress {
		value = wm.key[value]
	}
	return value, total
}
// Median returns the median of [start, end) via Kth.
// upper selects the upper median (rank n/2) instead of the lower one
// (rank (n-1)/2), where n = end-start.
func (wm *WaveletMatrixWithSum) Median(start, end int32, upper bool, xorVal WmValue) WmValue {
	size := end - start
	rank := (size - 1) / 2
	if upper {
		rank = size / 2
	}
	return wm.Kth(start, end, rank, xorVal)
}
// Sum returns the aggregate of the sum data of the elements of [start, end)
// whose rank (in increasing order of value) lies in [k1, k2).
// It returns the identity when k1 >= k2.
func (wm *WaveletMatrixWithSum) Sum(start, end, k1, k2 int32, xorVal WmValue) WmSum {
	if k2 <= k1 {
		return wm.e()
	}
	upTo := wm._prefixSum(start, end, k2, xorVal)
	before := wm._prefixSum(start, end, k1, xorVal)
	return wm.op(upTo, wm.inv(before))
}
// SumAll returns the aggregate of the sum data of every element in
// [start, end), read from the prefix table kept in the original order.
func (wm *WaveletMatrixWithSum) SumAll(start, end int32) WmSum {
	originalOrder := wm.log
	return wm._get(originalOrder, start, end)
}
// MaxRight takes the elements of [start, end) in increasing order of value
// (after xor with xorVal) and returns the largest (count, sum) prefix for
// which predicate(count, sum) is true.
//
// NOTE(review): the search presumably requires predicate to be monotone
// (true for a prefix implies true for every shorter prefix) — confirm with callers.
// It panics if xorVal is non-zero while log was not set explicitly.
func (wm *WaveletMatrixWithSum) MaxRight(predicate func(int32, WmSum) bool, start, end int32, xorVal WmValue) (int32, WmSum) {
	if xorVal != 0 && !wm.setLog {
		panic("log should be set when xor is used")
	}
	if start == end {
		return 0, wm.e()
	}
	// Fast path: the whole range already satisfies the predicate.
	if whole := wm._get(wm.log, start, end); predicate(end-start, whole) {
		return end - start, whole
	}

	taken := int32(0)
	acc := wm.e()
	for d := wm.log - 1; d >= 0; d-- {
		flipped := (xorVal>>d)&1 == 1
		zeroL, zeroR := wm.bv[d].Rank(start, false), wm.bv[d].Rank(end, false)
		oneL, oneR := start+wm.mid[d]-zeroL, end+wm.mid[d]-zeroR

		// "low" is the half whose xor-ed bit is 0, "high" the other one.
		lowL, lowR, highL, highR := zeroL, zeroR, oneL, oneR
		if flipped {
			lowL, lowR, highL, highR = oneL, oneR, zeroL, zeroR
		}

		lowCount := lowR - lowL
		candidate := wm.op(acc, wm._get(d, lowL, lowR))
		if predicate(taken+lowCount, candidate) {
			// The whole low half fits: keep it and continue in the high half.
			taken += lowCount
			acc = candidate
			start, end = highL, highR
		} else {
			start, end = lowL, lowR
		}
	}

	// The remaining elements all share one value; find how many of them fit.
	extra := wm.binarySearch(func(k int32) bool {
		return predicate(taken+k, wm.op(acc, wm._get(0, start, start+k)))
	}, 0, end-start)
	taken += extra
	acc = wm.op(acc, wm._get(0, start, start+extra))
	return taken, acc
}
// CountSegments returns the total, over all given [start, end) segments, of
// the number of elements whose value (after xor with xorVal) lies in [a, b).
func (wm *WaveletMatrixWithSum) CountSegments(segments [][2]int32, a, b WmValue, xorVal WmValue) int32 {
	var total int32
	for i := range segments {
		total += wm.Count(segments[i][0], segments[i][1], a, b, xorVal)
	}
	return total
}
// KthSegments returns the k-th smallest (0-indexed) value among the elements
// covered by the given [start, end) segments, where values are compared
// after xor with xorVal. The caller's segments slice is not modified.
//
// It panics if xorVal is non-zero while log was not set explicitly.
// k is expected to be smaller than the total length of the segments.
func (wm *WaveletMatrixWithSum) KthSegments(segments [][2]int32, k int32, xorVal WmValue) WmValue {
	if xorVal != 0 && !wm.setLog {
		panic("log should be set when xor is used")
	}
	// Ranging over a [][2]int32 by value yields copies, so the ranges are
	// tracked in a private slice and updated by index.
	segs := append(segments[:0:0], segments...)
	count := int32(0)
	res := WmValue(0)
	for d := wm.log - 1; d >= 0; d-- {
		f := (xorVal>>d)&1 == 1
		level, mid := wm.bv[d], wm.mid[d]

		// c: number of covered elements whose xor-ed bit d is 0.
		c := int32(0)
		for _, seg := range segs {
			zeros := level.Rank(seg[1], false) - level.Rank(seg[0], false)
			if f {
				c += (seg[1] - seg[0]) - zeros
			} else {
				c += zeros
			}
		}

		bitIsOne := count+c <= k
		if bitIsOne {
			count += c
			res |= 1 << d
		}
		// The stored bit to follow is the answer's xor-ed bit flipped by f.
		followZeros := bitIsOne == f
		for i := range segs {
			l0, r0 := level.Rank(segs[i][0], false), level.Rank(segs[i][1], false)
			if followZeros {
				segs[i] = [2]int32{l0, r0}
			} else {
				segs[i][0] += mid - l0
				segs[i][1] += mid - r0
			}
		}
	}
	if wm.compress {
		res = wm.key[res]
	}
	return res
}
// KthValueAndSumSegments returns, for the elements covered by the given
// [start, end) segments, the k-th smallest value (0-indexed, compared after
// xor with xorVal) and the aggregate of the sum data of the k smallest
// elements (the k-th one itself excluded). The caller's segments slice is
// not modified.
//
// If k is at least the total length of the segments it returns
// (-1, aggregate of all segments). It panics if xorVal is non-zero while log
// was not set explicitly.
func (wm *WaveletMatrixWithSum) KthValueAndSumSegments(segments [][2]int32, k int32, xorVal WmValue) (WmValue, WmSum) {
	if xorVal != 0 && !wm.setLog {
		panic("log should be set when xor is used")
	}
	totalLen := int32(0)
	for _, seg := range segments {
		totalLen += seg[1] - seg[0]
	}
	if k >= totalLen {
		return -1, wm.SumAllSegments(segments)
	}

	// Ranging over a [][2]int32 by value yields copies, so the ranges are
	// tracked in a private slice and updated by index.
	segs := append(segments[:0:0], segments...)
	count := int32(0)
	sum := wm.e()
	res := WmValue(0)
	for d := wm.log - 1; d >= 0; d-- {
		f := (xorVal>>d)&1 == 1
		level, mid := wm.bv[d], wm.mid[d]

		// c: number of covered elements whose xor-ed bit d is 0.
		c := int32(0)
		for _, seg := range segs {
			zeros := level.Rank(seg[1], false) - level.Rank(seg[0], false)
			if f {
				c += (seg[1] - seg[0]) - zeros
			} else {
				c += zeros
			}
		}

		if count+c > k {
			// The answer is in the smaller half: just descend into it.
			for i := range segs {
				l0, r0 := level.Rank(segs[i][0], false), level.Rank(segs[i][1], false)
				if f {
					segs[i][0] += mid - l0
					segs[i][1] += mid - r0
				} else {
					segs[i] = [2]int32{l0, r0}
				}
			}
			continue
		}

		// The whole smaller half precedes the answer: add it and descend
		// into the other half.
		count += c
		res |= 1 << d
		for i := range segs {
			l0, r0 := level.Rank(segs[i][0], false), level.Rank(segs[i][1], false)
			oneL, oneR := segs[i][0]+mid-l0, segs[i][1]+mid-r0
			if f {
				sum = wm.op(sum, wm._get(d, oneL, oneR))
				segs[i] = [2]int32{l0, r0}
			} else {
				sum = wm.op(sum, wm._get(d, l0, r0))
				segs[i] = [2]int32{oneL, oneR}
			}
		}
	}

	// The remaining elements all equal the answer; take the missing copies.
	for _, seg := range segs {
		t := min32(seg[1]-seg[0], k-count)
		sum = wm.op(sum, wm._get(0, seg[0], seg[0]+t))
		count += t
	}
	if wm.compress {
		res = wm.key[res]
	}
	return res, sum
}
// MedianSegments returns the median of the elements covered by the given
// [start, end) segments via KthSegments. upper selects the upper median
// (rank n/2) instead of the lower one (rank (n-1)/2), where n is the total
// length of the segments.
func (wm *WaveletMatrixWithSum) MedianSegments(segments [][2]int32, upper bool, xorVal WmValue) WmValue {
	var total int32
	for i := range segments {
		total += segments[i][1] - segments[i][0]
	}
	rank := (total - 1) / 2
	if upper {
		rank = total / 2
	}
	return wm.KthSegments(segments, rank, xorVal)
}
// SumAllSegments returns the aggregate of the sum data of every element
// covered by the given [start, end) segments.
func (wm *WaveletMatrixWithSum) SumAllSegments(segments [][2]int32) WmSum {
	total := wm.e()
	for i := range segments {
		part := wm._get(wm.log, segments[i][0], segments[i][1])
		total = wm.op(total, part)
	}
	return total
}
// _prefixCount returns the number of elements in [start, end) whose value
// (after xor with xor) lies in [0, x). A non-positive x therefore yields 0.
//
// It panics if xor is non-zero while log was not set explicitly.
func (wm *WaveletMatrixWithSum) _prefixCount(start, end int32, x WmValue, xor WmValue) int32 {
	if xor != 0 && !wm.setLog {
		panic("log should be set when xor is used")
	}
	if wm.compress {
		x = wm.lowerBound(wm.key, x)
	}
	// [0, x) is empty for x <= 0. Negative bounds must stop here: their
	// sign-extended bits would otherwise be walked as if they were a value.
	if x <= 0 {
		return 0
	}
	if x >= 1<<wm.log {
		return end - start
	}
	count := int32(0)
	for d := wm.log - 1; d >= 0; d-- {
		add := (x>>d)&1 == 1
		f := (xor>>d)&1 == 1
		l0, r0 := wm.bv[d].Rank(start, false), wm.bv[d].Rank(end, false)
		if add {
			// Every element whose xor-ed bit d is 0 is smaller than x.
			if f {
				count += (end - start) - (r0 - l0)
			} else {
				count += r0 - l0
			}
		}
		// Continue with the elements whose xor-ed bit equals x's bit, i.e.
		// stored bit 0 when add == f and stored bit 1 otherwise.
		if add == f {
			start, end = l0, r0
		} else {
			start += wm.mid[d] - l0
			end += wm.mid[d] - r0
		}
	}
	return count
}
// _prefixSum returns the aggregate of the sum data of the k smallest
// elements of [start, end); it is the sum part of KthValueAndSum.
func (wm *WaveletMatrixWithSum) _prefixSum(start, end, k int32, xor WmValue) WmSum {
	_, total := wm.KthValueAndSum(start, end, k, xor)
	return total
}
// _prefixSumSegments returns the aggregate of the sum data of the k smallest
// elements covered by segments; it is the sum part of KthValueAndSumSegments.
func (wm *WaveletMatrixWithSum) _prefixSumSegments(segments [][2]int32, k int32, xor WmValue) WmSum {
	_, total := wm.KthValueAndSumSegments(segments, k, xor)
	return total
}
// _get returns the aggregate of positions [l, r) of level d, computed from
// that level's prefix table as inv(prefix[l]) combined with prefix[r].
func (wm *WaveletMatrixWithSum) _get(d, l, r int32) WmSum {
	prefix := wm.presum[d]
	before := wm.inv(prefix[l])
	return wm.op(before, prefix[r])
}
// argSort returns the indices of nums ordered so that the referenced values
// are non-decreasing. nums itself is not modified.
func (wm *WaveletMatrixWithSum) argSort(nums []WmValue) []int32 {
	order := make([]int32, len(nums))
	for idx := 0; idx < len(order); idx++ {
		order[idx] = int32(idx)
	}
	byValue := func(a, b int) bool {
		return nums[order[a]] < nums[order[b]]
	}
	sort.Slice(order, byValue)
	return order
}
// maxs returns the largest value in nums. nums must be non-empty; indexing
// the first element panics otherwise.
func (wm *WaveletMatrixWithSum) maxs(nums []WmValue) WmValue {
	largest := nums[0]
	for i := 1; i < len(nums); i++ {
		if largest < nums[i] {
			largest = nums[i]
		}
	}
	return largest
}
// lowerBound returns the index of the first element of the sorted slice nums
// that is not smaller than target, or len(nums) if there is none.
func (wm *WaveletMatrixWithSum) lowerBound(nums []WmValue, target WmValue) WmValue {
	// Invariant: everything before lo is < target, everything from hi on is >= target.
	lo, hi := int32(0), int32(len(nums))
	for lo < hi {
		probe := (lo + hi) >> 1
		if nums[probe] < target {
			lo = probe + 1
		} else {
			hi = probe
		}
	}
	return WmValue(lo)
}
// binarySearch narrows the pair (ok, ng) by repeatedly testing the midpoint
// with f — a true result moves ok, a false result moves ng — until the two
// are at most one apart, and returns ok. ok may be above or below ng.
func (wm *WaveletMatrixWithSum) binarySearch(f func(int32) bool, ok, ng int32) int32 {
	for {
		gap := ok - ng
		if gap < 0 {
			gap = -gap
		}
		if gap <= 1 {
			return ok
		}
		probe := (ok + ng) >> 1
		if f(probe) {
			ok = probe
		} else {
			ng = probe
		}
	}
}
// abs32 returns the absolute value of x.
func abs32(x int32) int32 {
	if x >= 0 {
		return x
	}
	return -x
}
// min32 returns the smaller of a and b.
func min32(a, b int32) int32 {
	if b <= a {
		return b
	}
	return a
}
// BitVector is a bit set with rank queries. Every entry of data covers 32
// consecutive positions: [0] holds the bits and [1] holds, once Build has
// run, the number of set bits in all earlier entries.
type BitVector struct {
	data [][2]int
}

// NewBitVector returns a BitVector with room for positions [0, n] and no bit set.
func NewBitVector(n int32) *BitVector {
	words := (n + 63) >> 5
	return &BitVector{data: make([][2]int, words)}
}

// Set turns on bit i. Call it before Build.
func (bv *BitVector) Set(i int32) {
	word, bit := i>>5, i&31
	bv.data[word][0] |= 1 << bit
}

// Build fills in the running counts of set bits that Rank relies on.
func (bv *BitVector) Build() {
	for i := 1; i < len(bv.data); i++ {
		prev := bv.data[i-1]
		bv.data[i][1] = prev[1] + bits.OnesCount(uint(prev[0]))
	}
}

// Rank returns the number of 1 bits in [0, k) when f is true, and the number
// of 0 bits in [0, k) when f is false.
func (bv *BitVector) Rank(k int32, f bool) int32 {
	entry := bv.data[k>>5]
	below := int(1)<<(k&31) - 1 // mask of the positions before k inside its entry
	ones := int32(entry[1] + bits.OnesCount(uint(entry[0]&below)))
	if f {
		return ones
	}
	return k - ones
}
// // TODO: uint64 or uint32 ??
// type BitVector struct {
// bits []uint64
// preSum []int32
// }
// func NewBitVector(n int32) *BitVector {
// return &BitVector{bits: make([]uint64, (n+63)>>6), preSum: make([]int32, (n+63)>>6)}
// }
// func (bv *BitVector) Set(i int32) {
// bv.bits[i>>6] |= 1 << (i & 63)
// }
// func (bv *BitVector) Build() {
// for i := 0; i < len(bv.bits)-1; i++ {
// bv.preSum[i+1] = bv.preSum[i] + int32(bits.OnesCount64(bv.bits[i]))
// }
// }
// func (bv *BitVector) Rank(k int32, f bool) int32 {
// m, s := bv.bits[k>>6], bv.preSum[k>>6]
// res := s + int32(bits.OnesCount64(m&((1<<(k&63))-1)))
// if f {
// return res
// }
// return k - res
// }