結果
問題 | No.1332 Range Nearest Query |
ユーザー |
|
提出日時 | 2024-04-15 00:02:32 |
言語 | Go (1.23.4) |
結果 |
WA
|
実行時間 | - |
コード長 | 18,856 bytes |
コンパイル時間 | 12,736 ms |
コンパイル使用メモリ | 238,260 KB |
実行使用メモリ | 18,816 KB |
最終ジャッジ日時 | 2024-10-04 06:10:28 |
合計ジャッジ時間 | 23,423 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 5 WA * 5 RE * 38 |
ソースコード
// This file implements a wavelet matrix that, in addition to order-statistic
// queries on values, maintains prefix sums of a per-element weight array
// ("sumData"), plus a few contest drivers that use it.
//
// Conventions:
//   - Positions are 0-indexed; every position range is half-open [start, end).
//   - v denotes a stored value. Queries taking xor compare and order elements
//     by v^xor; a non-zero xor requires the matrix to be built with an
//     explicit log.
//   - xor only affects ordering/selection: every sum returned is a sum of the
//     original sumData weights, never of v^xor.
//
// Queries:
//
//	CountRange(start, end, a, b, xor)    count of elements with a <= v^xor < b
//	CountPrefix(start, end, x, xor)      count of elements with v^xor < x
//	SumRange(start, end, k1, k2, xor)    weight sum of the k1-th .. (k2-1)-th smallest elements
//	SumPrefix(start, end, k, xor)        weight sum of the k smallest elements
//	Kth(start, end, k, xor)              k-th (0-indexed) smallest v^xor
//	KthValueAndSum(start, end, k, xor)   k-th smallest v^xor and the weight sum of the k elements before it
//	Median(start, end, upper, xor)       upper or lower median
//	SumAll(start, end)                   weight sum of the whole range
//	Floor(start, end, x, xor)            largest v^xor <= x, or -INF
//	Ceil(start, end, x, xor)             smallest v^xor >= x, or INF
//	MaxRight(predicate, start, end, xor) largest (count, sum) over a sorted prefix satisfying predicate
package main

import (
	"bufio"
	"fmt"
	"math"
	"math/bits"
	"os"
	"sort"
)

func main() {
	// demo()
	// 区间最短距离和()
	// abc281_e()
	yuki1332()
}

// demo prints the results of a few sample queries; main does not call it.
func demo() {
	nums := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
	wm := NewWaveletMatrixWithSum(nums, nums, -1, false)
	fmt.Println(wm.CountRange(0, 3, 3, 7, 0))
	fmt.Println(wm.KthValueAndSum(1, 5, 0, 0))

	nums = []int{1, 2, 3, 4, 5}
	wm = NewWaveletMatrixWithSum(nums, nums, -1, false)
	fmt.Println(wm.KthValueAndSum(0, 5, 2, 0))
	// k equal to the range length: there is no k-th element, so this prints
	// -1 together with the sum of the whole range. end must stay <= len(nums).
	fmt.Println(wm.KthValueAndSum(0, 5, 5, 0))
	fmt.Println(wm.SumAll(0, 4))
	fmt.Println(wm.Floor(0, 5, 4, 0))
}

// abc281_e reads n, m, k and nums, then for every window nums[i:i+m] prints
// the sum of its k smallest elements (per its name, AtCoder ABC281 E).
func abc281_e() {
	in := bufio.NewReader(os.Stdin)
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()

	var n, m, k int32
	fmt.Fscan(in, &n, &m, &k)
	nums := make([]int, n)
	for i := range nums {
		fmt.Fscan(in, &nums[i])
	}
	wm := NewWaveletMatrixWithSum(nums, nums, -1, false)
	for i := int32(0); i < n-m+1; i++ {
		_, res := wm.KthValueAndSum(i, i+m, k, 0)
		fmt.Fprintln(out, res)
	}
}

// 区间最短距离和 ("range minimum distance sum") solves yukicoder No.924:
// https://yukicoder.me/problems/no/924
//
// n, q <= 2e5 and -1e9 <= nums[i] <= 1e9. For each query [l, r] (1-indexed,
// inclusive) it prints the minimum over x of f(x) = sum_{l<=i<=r} |nums[i]-x|,
// which is attained at the median of nums[l..r].
func 区间最短距离和() {
	in := bufio.NewReader(os.Stdin)
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()

	var n, q int
	fmt.Fscan(in, &n, &q)
	// The wavelet matrix works on non-negative values, so shift everything up;
	// the shift cancels out in every |a-b|.
	OFFSET := int(1e9 + 10)
	nums := make([]int, n)
	for i := range nums {
		fmt.Fscan(in, &nums[i])
		nums[i] += OFFSET
	}
	preSum := make([]int, n+1)
	for i := 0; i < n; i++ {
		preSum[i+1] = preSum[i] + nums[i]
	}
	wm := NewWaveletMatrixWithSum(nums, nums, -1, false)
	for i := 0; i < q; i++ {
		var start, end int32
		fmt.Fscan(in, &start, &end)
		start--
		size := end - start
		lowerCount := size / 2
		ceilCount := size - lowerCount
		// mid is the median; lowerSum is the sum of the lowerCount elements before it.
		mid, lowerSum := wm.KthValueAndSum(start, end, lowerCount, 0)
		allSum := preSum[end] - preSum[start]
		ceilSum := allSum - lowerSum
		res := mid*int(lowerCount) - lowerSum
		res += ceilSum - mid*int(ceilCount)
		fmt.Fprintln(out, res)
	}
}

// Reference solution of No.1332 in C++ (library style), kept for comparison:
//
//	LL(N);
//	VEC(ll, X, N);
//	Wavelet_Matrix<ll, true> WM(X);
//	LL(Q);
//
//	FOR(Q) {
//	  LL(l, r, x);
//	  --l;
//	  ll ANS = infty<ll>;
//	  ll n = WM.count(l, r, 0, x);
//	  if (n > 0) chmin(ANS, abs(x - WM.kth(l, r, n - 1)));
//	  if (n < r - l) chmin(ANS, abs(x - WM.kth(l, r, n)));
//	  print(ANS);
//	}

// yuki1332 solves yukicoder No.1332 "Range Nearest Query":
// https://yukicoder.me/problems/no/1332
//
// It answers nearest-value queries on an unsorted array (range predecessor /
// successor). For each query (l, r, x), with l and r 1-indexed and inclusive,
// it prints min_{l<=i<=r} |nums[i]-x|. The nearest value is either the
// largest element <= x (Floor) or the smallest element >= x (Ceil).
// n <= 3e5, q <= 1e5, nums[i] <= 1e9.
func yuki1332() {
	in := bufio.NewReader(os.Stdin)
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()

	var n int
	fmt.Fscan(in, &n)
	nums := make([]int, n)
	for i := 0; i < n; i++ {
		fmt.Fscan(in, &nums[i])
	}
	M := NewWaveletMatrixWithSum(nums, nil, -1, true)
	var q int
	fmt.Fscan(in, &q)
	for i := 0; i < q; i++ {
		var start, end int32
		var x int
		fmt.Fscan(in, &start, &end, &x)
		start--
		res := INF
		floor := M.Floor(start, end, x, 0) // largest value <= x
		if floor != -INF {
			res = min(res, abs(floor-x))
		}
		ceiling := M.Ceil(start, end, x, 0) // smallest value >= x
		if ceiling != INF {
			res = min(res, abs(ceiling-x))
		}
		fmt.Fprintln(out, res)
	}
}

// INF is returned by Ceil, and -INF by Floor, when no element qualifies.
const INF WmValue = 1e18

// WmValue is the type of the values stored in the wavelet matrix.
type WmValue = int

// WmSum is the type of the per-element weights whose range sums are maintained.
type WmSum = int

// e is the identity element of the sum monoid.
func (*WaveletMatrixWithSum) e() WmSum { return 0 }

// op combines two partial sums.
func (*WaveletMatrixWithSum) op(a, b WmSum) WmSum { return a + b }

// inv is the inverse of op; it turns differences of prefix sums into range sums.
func (*WaveletMatrixWithSum) inv(a WmSum) WmSum { return -a }

// WaveletMatrixWithSum is a wavelet matrix whose levels also carry prefix
// sums of an associated weight array.
type WaveletMatrixWithSum struct {
	n, log   int32        // number of elements; number of bit levels
	mid      []int32      // mid[d]: number of elements whose bit d is 0 (where the "ones" start one level below)
	bv       []*BitVector // bv[d]: bit d of every element, in the element order of level d
	key      []WmValue    // sorted distinct original values (only when compress is true)
	setLog   bool         // log was given explicitly; required for xor queries
	presum   [][]WmSum    // presum[d+1]: weight prefix sums in the order before splitting on bit d; presum[0]: final order
	compress bool         // values were replaced by their rank in key
}

// NewWaveletMatrixWithSum builds a wavelet matrix over nums.
//
//   - sumData: per-element weights, one per element of nums, whose range sums
//     are maintained. nil/empty means sum queries are unavailable.
//   - log: number of bit levels. Pass it explicitly to enable xor queries, or
//     -1 to derive it from the maximum value.
//   - compress: replace every value by its rank among the distinct values.
//     Without compression, values are decomposed into log bits and are
//     expected to lie in [0, 1<<log). compress cannot be combined with an
//     explicit log.
func NewWaveletMatrixWithSum(nums []WmValue, sumData []WmSum, log int32, compress bool) *WaveletMatrixWithSum {
	wm := &WaveletMatrixWithSum{}
	wm.build(nums, sumData, log, compress)
	return wm
}

// build fills wm; see NewWaveletMatrixWithSum for the meaning of the arguments.
// The inputs are copied and never modified.
func (wm *WaveletMatrixWithSum) build(nums []WmValue, sumData []WmSum, log int32, compress bool) {
	numsCopy := append(nums[:0:0], nums...)
	sumDataCopy := append(sumData[:0:0], sumData...)
	wm.n = int32(len(numsCopy))
	wm.log = log
	wm.compress = compress
	wm.setLog = log != -1
	if wm.n == 0 {
		wm.log = 0
		return
	}
	makeSum := len(sumData) > 0
	if compress {
		if wm.setLog {
			panic("compress and log should not be set at the same time")
		}
		// Replace each value by its index among the sorted distinct values.
		wm.key = make([]WmValue, 0, wm.n)
		order := wm.argSort(numsCopy)
		for _, i := range order {
			if len(wm.key) == 0 || wm.key[len(wm.key)-1] != numsCopy[i] {
				wm.key = append(wm.key, numsCopy[i])
			}
			numsCopy[i] = WmValue(len(wm.key) - 1)
		}
	}
	if wm.log == -1 {
		maxVal := wm.maxs(numsCopy)
		if maxVal < 1 {
			maxVal = 1
		}
		wm.log = int32(bits.Len(uint(maxVal)))
	}
	wm.mid = make([]int32, wm.log)
	wm.bv = make([]*BitVector, wm.log)
	for i := range wm.bv {
		wm.bv[i] = NewBitVector(wm.n)
	}
	if makeSum {
		wm.presum = make([][]WmSum, 1+wm.log)
		for i := range wm.presum {
			sums := make([]WmSum, wm.n+1)
			for j := range sums {
				sums[j] = wm.e()
			}
			wm.presum[i] = sums
		}
	}
	if len(sumDataCopy) == 0 {
		sumDataCopy = make([]WmSum, len(numsCopy))
	}

	A, S := numsCopy, sumDataCopy
	A0, A1 := make([]WmValue, wm.n), make([]WmValue, wm.n)
	S0, S1 := make([]WmSum, wm.n), make([]WmSum, wm.n)
	for d := wm.log - 1; d >= -1; d-- {
		if makeSum {
			// Prefix sums of the weights in the current order, i.e. before
			// splitting on bit d (for d == -1: the final order).
			tmp := wm.presum[d+1]
			for i := int32(0); i < wm.n; i++ {
				tmp[i+1] = wm.op(tmp[i], S[i])
			}
		}
		if d == -1 {
			break
		}
		p0, p1 := int32(0), int32(0)
		for i := int32(0); i < wm.n; i++ {
			if (A[i]>>d)&1 == 0 {
				if makeSum {
					S0[p0] = S[i]
				}
				A0[p0] = A[i]
				p0++
			} else {
				if makeSum {
					S1[p1] = S[i]
				}
				wm.bv[d].Set(i)
				A1[p1] = A[i]
				p1++
			}
		}
		wm.mid[d] = p0
		wm.bv[d].Build()
		// Stable partition for the next level: zeros first, then ones.
		A, A0 = A0, A
		S, S0 = S0, S
		copy(A[p0:], A1[:p1])
		copy(S[p0:], S1[:p1])
	}
}

// down maps the position range [start, end) of level d to the range, one
// level below, occupied by those of its elements whose stored bit d equals
// bit. l0 and r0 must be bv[d].Rank(start, false) and bv[d].Rank(end, false).
func (wm *WaveletMatrixWithSum) down(d, start, end, l0, r0 int32, bit bool) (int32, int32) {
	if !bit {
		return l0, r0
	}
	// Ones follow all mid[d] zeros in their original relative order, and
	// start-l0 is the number of ones before start.
	return start + wm.mid[d] - l0, end + wm.mid[d] - r0
}

// lowCount returns how many elements of [start, end) have bit d of v^xor
// equal to 0, given l0/r0 as in down and flip = bit d of xor.
func lowCount(start, end, l0, r0 int32, flip bool) int32 {
	if flip {
		return (end - start) - (r0 - l0)
	}
	return r0 - l0
}

// CountRange returns the number of elements in [start, end) with a <= v^xor < b.
func (wm *WaveletMatrixWithSum) CountRange(start, end int32, a, b WmValue, xorVal WmValue) int32 {
	return wm.CountPrefix(start, end, b, xorVal) - wm.CountPrefix(start, end, a, xorVal)
}

// SumRange returns the weight sum of the k1-th .. (k2-1)-th (0-indexed)
// smallest elements of [start, end), ordered by v^xor.
func (wm *WaveletMatrixWithSum) SumRange(start, end, k1, k2 int32, xorVal WmValue) WmSum {
	if k1 >= k2 {
		return wm.e()
	}
	add := wm.SumPrefix(start, end, k2, xorVal)
	sub := wm.SumPrefix(start, end, k1, xorVal)
	return wm.op(add, wm.inv(sub))
}

// CountPrefix returns the number of elements in [start, end) with v^xor < x.
// With compress, x is an original (uncompressed) value.
func (wm *WaveletMatrixWithSum) CountPrefix(start, end int32, x WmValue, xor WmValue) int32 {
	if xor != 0 && !wm.setLog {
		panic("log should be set when xor is used")
	}
	if wm.compress {
		x = wm.lowerBound(wm.key, x)
	}
	// Every represented value (v^xor, or a rank) lies in [0, 1<<log), so
	// nothing is below x <= 0 and everything is below x >= 1<<log.
	if x <= 0 {
		return 0
	}
	if x >= 1<<wm.log {
		return end - start
	}
	count := int32(0)
	for d := wm.log - 1; d >= 0; d-- {
		f := (xor>>d)&1 == 1
		l0, r0 := wm.bv[d].Rank(start, false), wm.bv[d].Rank(end, false)
		if (x>>d)&1 == 1 {
			// Elements whose bit d (after xor) is 0 agree with x on the higher
			// bits and are smaller: count them, then follow x's 1 bit.
			count += lowCount(start, end, l0, r0, f)
			start, end = wm.down(d, start, end, l0, r0, !f)
		} else {
			start, end = wm.down(d, start, end, l0, r0, f)
		}
	}
	return count
}

// SumPrefix returns the weight sum of the k smallest elements of
// [start, end) (all of them if k >= end-start).
func (wm *WaveletMatrixWithSum) SumPrefix(start, end, k int32, xor WmValue) WmSum {
	_, sum := wm.KthValueAndSum(start, end, k, xor)
	return sum
}

// Kth returns the k-th (0-indexed) smallest v^xor in [start, end) (the
// original value when compress is set). It returns -1 if k is outside
// [0, end-start); note that -1 may also be a legitimate compressed value.
func (wm *WaveletMatrixWithSum) Kth(start, end, k int32, xorVal WmValue) WmValue {
	if xorVal != 0 && !wm.setLog {
		panic("log should be set")
	}
	if k < 0 || k >= end-start {
		return -1
	}
	count := int32(0)
	res := WmValue(0)
	for d := wm.log - 1; d >= 0; d-- {
		f := (xorVal>>d)&1 == 1
		l0, r0 := wm.bv[d].Rank(start, false), wm.bv[d].Rank(end, false)
		c := lowCount(start, end, l0, r0, f)
		if count+c > k {
			// The answer's bit d (after xor) is 0, i.e. its stored bit d is f.
			start, end = wm.down(d, start, end, l0, r0, f)
		} else {
			// Skip the c smaller elements; the answer's bit d (after xor) is 1.
			count += c
			res |= 1 << d
			start, end = wm.down(d, start, end, l0, r0, !f)
		}
	}
	if wm.compress {
		res = wm.key[res]
	}
	return res
}

// KthValueAndSum returns the k-th (0-indexed) smallest v^xor in [start, end)
// together with the weight sum of the k elements before it (excluding the
// k-th itself).
//
// For an empty range or k < 0 it returns (-1, e()). For k >= end-start it
// returns (-1, SumAll(start, end)). It panics if the matrix was built
// without sumData, unless one of the early returns for an empty range or a
// negative k applies.
func (wm *WaveletMatrixWithSum) KthValueAndSum(start, end, k int32, xorVal WmValue) (WmValue, WmSum) {
	if start >= end || k < 0 {
		return -1, wm.e()
	}
	if len(wm.presum) == 0 {
		panic("sumData is not provided")
	}
	if k >= end-start {
		return -1, wm.SumAll(start, end)
	}
	if xorVal != 0 && !wm.setLog {
		panic("log should be set when xor is used")
	}
	count := int32(0)
	sum := wm.e()
	res := WmValue(0)
	for d := wm.log - 1; d >= 0; d-- {
		f := (xorVal>>d)&1 == 1
		l0, r0 := wm.bv[d].Rank(start, false), wm.bv[d].Rank(end, false)
		c := lowCount(start, end, l0, r0, f)
		if count+c > k {
			start, end = wm.down(d, start, end, l0, r0, f)
		} else {
			// All c elements on the smaller side precede the answer.
			s0, s1 := wm.down(d, start, end, l0, r0, f)
			sum = wm.op(sum, wm._get(d, s0, s1))
			count += c
			res |= 1 << d
			start, end = wm.down(d, start, end, l0, r0, !f)
		}
	}
	// The remaining range holds only copies of the answer; take the first k-count.
	sum = wm.op(sum, wm._get(0, start, start+k-count))
	if wm.compress {
		res = wm.key[res]
	}
	return res, sum
}

// Median returns the median of [start, end) by v^xor: the upper median
// (index n/2) if upper is true, otherwise the lower one (index (n-1)/2).
func (wm *WaveletMatrixWithSum) Median(start, end int32, upper bool, xorVal WmValue) WmValue {
	n := end - start
	var k int32
	if upper {
		k = n / 2
	} else {
		k = (n - 1) / 2
	}
	return wm.Kth(start, end, k, xorVal)
}

// SumAll returns the weight sum of [start, end).
func (wm *WaveletMatrixWithSum) SumAll(start, end int32) WmSum {
	return wm._get(wm.log, start, end)
}

// MaxRight returns the largest (count, sum) such that predicate(count, sum)
// holds, where the candidates are the count smallest elements of
// [start, end) by v^xor and sum is their weight sum. predicate is assumed to
// be monotone (true, then false as count grows) and true for (0, e()).
func (wm *WaveletMatrixWithSum) MaxRight(predicate func(int32, WmSum) bool, start, end int32, xorVal WmValue) (int32, WmSum) {
	if xorVal != 0 && !wm.setLog {
		panic("log should be set when xor is used")
	}
	if start == end {
		return end - start, wm.e()
	}
	if s := wm._get(wm.log, start, end); predicate(end-start, s) {
		return end - start, s
	}
	count := int32(0)
	sum := wm.e()
	for d := wm.log - 1; d >= 0; d-- {
		f := (xorVal>>d)&1 == 1
		l0, r0 := wm.bv[d].Rank(start, false), wm.bv[d].Rank(end, false)
		c := lowCount(start, end, l0, r0, f)
		s0, s1 := wm.down(d, start, end, l0, r0, f)
		if tmp := wm.op(sum, wm._get(d, s0, s1)); predicate(count+c, tmp) {
			// The whole smaller side fits: take it and continue on the larger side.
			count += c
			sum = tmp
			start, end = wm.down(d, start, end, l0, r0, !f)
		} else {
			start, end = s0, s1
		}
	}
	// The remaining elements are equal by value; binary search how many fit.
	k := wm.binarySearch(func(k int32) bool {
		return predicate(count+k, wm.op(sum, wm._get(0, start, start+k)))
	}, 0, end-start)
	count += k
	sum = wm.op(sum, wm._get(0, start, start+k))
	return count, sum
}

// Floor returns the largest v^xor <= x in [start, end), or -INF if there is none.
func (wm *WaveletMatrixWithSum) Floor(start, end int32, x WmValue, xor WmValue) WmValue {
	// Count elements <= x as elements < x+1, avoiding overflow at MaxInt.
	var notGreater int32
	if x == math.MaxInt {
		notGreater = end - start
	} else {
		notGreater = wm.CountPrefix(start, end, x+1, xor)
	}
	if notGreater <= 0 {
		return -INF
	}
	return wm.Kth(start, end, notGreater-1, xor)
}

// Ceil returns the smallest v^xor >= x in [start, end), or INF if there is none.
func (wm *WaveletMatrixWithSum) Ceil(start, end int32, x WmValue, xor WmValue) WmValue {
	less := wm.CountPrefix(start, end, x, xor)
	if less == end-start {
		return INF
	}
	return wm.Kth(start, end, less, xor)
}

// CountSegments returns the number of elements with a <= v^xor < b over the
// union of the given [start, end) segments.
func (wm *WaveletMatrixWithSum) CountSegments(segments [][2]int32, a, b WmValue, xorVal WmValue) int32 {
	res := int32(0)
	for _, seg := range segments {
		res += wm.CountRange(seg[0], seg[1], a, b, xorVal)
	}
	return res
}

// lowCountSegments sums lowCount at level d over all segments.
func (wm *WaveletMatrixWithSum) lowCountSegments(d int32, segs [][2]int32, flip bool) int32 {
	c := int32(0)
	for _, seg := range segs {
		l0, r0 := wm.bv[d].Rank(seg[0], false), wm.bv[d].Rank(seg[1], false)
		c += lowCount(seg[0], seg[1], l0, r0, flip)
	}
	return c
}

// downSegments replaces every segment of segs, in place, by its image one
// level below restricted to elements whose stored bit d equals bit.
func (wm *WaveletMatrixWithSum) downSegments(d int32, segs [][2]int32, bit bool) {
	for i := range segs {
		l0, r0 := wm.bv[d].Rank(segs[i][0], false), wm.bv[d].Rank(segs[i][1], false)
		segs[i][0], segs[i][1] = wm.down(d, segs[i][0], segs[i][1], l0, r0, bit)
	}
}

// KthSegments returns the k-th (0-indexed) smallest v^xor over the union of
// the given segments, or -1 if k is outside [0, total length). The caller's
// segments are not modified.
func (wm *WaveletMatrixWithSum) KthSegments(segments [][2]int32, k int32, xorVal WmValue) WmValue {
	if xorVal != 0 && !wm.setLog {
		panic("log should be set when xor is used")
	}
	// Private copy: the per-level updates must persist, which a
	// `for _, seg := range` variable (a copy of the array) cannot do.
	segs := append(segments[:0:0], segments...)
	totalLen := int32(0)
	for _, seg := range segs {
		totalLen += seg[1] - seg[0]
	}
	if k < 0 || k >= totalLen {
		return -1
	}
	count := int32(0)
	res := WmValue(0)
	for d := wm.log - 1; d >= 0; d-- {
		f := (xorVal>>d)&1 == 1
		c := wm.lowCountSegments(d, segs, f)
		bit := f // stored bit of the side whose bit d (after xor) is 0
		if count+c <= k {
			count += c
			res |= 1 << d
			bit = !f
		}
		wm.downSegments(d, segs, bit)
	}
	if wm.compress {
		res = wm.key[res]
	}
	return res
}

// KthValueAndSumSegments is KthValueAndSum over the union of the given
// segments. For k >= total length it returns (-1, SumAllSegments), and for
// k < 0 it returns (-1, e()). The caller's segments are not modified.
func (wm *WaveletMatrixWithSum) KthValueAndSumSegments(segments [][2]int32, k int32, xorVal WmValue) (WmValue, WmSum) {
	if xorVal != 0 && !wm.setLog {
		panic("log should be set when xor is used")
	}
	segs := append(segments[:0:0], segments...)
	totalLen := int32(0)
	for _, seg := range segs {
		totalLen += seg[1] - seg[0]
	}
	if k >= totalLen {
		return -1, wm.SumAllSegments(segs)
	}
	if k < 0 {
		return -1, wm.e()
	}
	if len(wm.presum) == 0 {
		panic("sumData is not provided")
	}
	count := int32(0)
	sum := wm.e()
	res := WmValue(0)
	for d := wm.log - 1; d >= 0; d-- {
		f := (xorVal>>d)&1 == 1
		c := wm.lowCountSegments(d, segs, f)
		if count+c > k {
			wm.downSegments(d, segs, f)
			continue
		}
		count += c
		res |= 1 << d
		for i := range segs {
			l0, r0 := wm.bv[d].Rank(segs[i][0], false), wm.bv[d].Rank(segs[i][1], false)
			s0, s1 := wm.down(d, segs[i][0], segs[i][1], l0, r0, f)
			sum = wm.op(sum, wm._get(d, s0, s1))
			segs[i][0], segs[i][1] = wm.down(d, segs[i][0], segs[i][1], l0, r0, !f)
		}
	}
	// Remaining elements all equal the answer; take k-count of them, segment by segment.
	for i := range segs {
		t := min32(segs[i][1]-segs[i][0], k-count)
		sum = wm.op(sum, wm._get(0, segs[i][0], segs[i][0]+t))
		count += t
	}
	if wm.compress {
		res = wm.key[res]
	}
	return res, sum
}

// MedianSegments returns the upper or lower median over the union of segments.
func (wm *WaveletMatrixWithSum) MedianSegments(segments [][2]int32, upper bool, xorVal WmValue) WmValue {
	n := int32(0)
	for _, seg := range segments {
		n += seg[1] - seg[0]
	}
	var k int32
	if upper {
		k = n / 2
	} else {
		k = (n - 1) / 2
	}
	return wm.KthSegments(segments, k, xorVal)
}

// SumAllSegments returns the weight sum over the union of segments.
func (wm *WaveletMatrixWithSum) SumAllSegments(segments [][2]int32) WmSum {
	sum := wm.e()
	for _, seg := range segments {
		sum = wm.op(sum, wm._get(wm.log, seg[0], seg[1]))
	}
	return sum
}

// _prefixSumSegments returns the weight sum of the k smallest elements over the union of segments.
func (wm *WaveletMatrixWithSum) _prefixSumSegments(segments [][2]int32, k int32, xor WmValue) WmSum {
	_, sum := wm.KthValueAndSumSegments(segments, k, xor)
	return sum
}

// _get returns the weight sum of positions [l, r) in the order stored in presum[d].
func (wm *WaveletMatrixWithSum) _get(d, l, r int32) WmSum {
	return wm.op(wm.inv(wm.presum[d][l]), wm.presum[d][r])
}

// argSort returns the indices of nums ordered by ascending value.
func (wm *WaveletMatrixWithSum) argSort(nums []WmValue) []int32 {
	order := make([]int32, len(nums))
	for i := range order {
		order[i] = int32(i)
	}
	sort.Slice(order, func(i, j int) bool { return nums[order[i]] < nums[order[j]] })
	return order
}

// maxs returns the maximum of a non-empty slice.
func (wm *WaveletMatrixWithSum) maxs(nums []WmValue) WmValue {
	res := nums[0]
	for _, v := range nums {
		if v > res {
			res = v
		}
	}
	return res
}

// lowerBound returns the number of elements of the sorted slice nums that are < target.
func (wm *WaveletMatrixWithSum) lowerBound(nums []WmValue, target WmValue) WmValue {
	left, right := int32(0), int32(len(nums)-1)
	for left <= right {
		mid := (left + right) >> 1
		if nums[mid] < target {
			left = mid + 1
		} else {
			right = mid - 1
		}
	}
	return WmValue(left)
}

// binarySearch narrows (ok, ng) until they are adjacent and returns ok,
// assuming f is true at ok and false at ng. f is only evaluated strictly
// between them.
func (wm *WaveletMatrixWithSum) binarySearch(f func(int32) bool, ok, ng int32) int32 {
	for abs32(ok-ng) > 1 {
		x := (ok + ng) >> 1
		if f(x) {
			ok = x
		} else {
			ng = x
		}
	}
	return ok
}

func abs(x int) int {
	if x < 0 {
		return -x
	}
	return x
}

func abs32(x int32) int32 {
	if x < 0 {
		return -x
	}
	return x
}

func min32(a, b int32) int32 {
	if a < b {
		return a
	}
	return b
}

func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

// BitVector is a fixed-size bit array with constant-time rank queries.
type BitVector struct {
	bits   []uint64
	preSum []int32 // preSum[i]: number of set bits in bits[0:i]
}

// NewBitVector allocates a bit vector for n positions. It uses n>>6+1 words
// so that Rank(n) can read bits[n>>6] even when n is a multiple of 64.
func NewBitVector(n int32) *BitVector {
	return &BitVector{bits: make([]uint64, n>>6+1), preSum: make([]int32, n>>6+1)}
}

// Set sets bit i.
func (bv *BitVector) Set(i int32) {
	bv.bits[i>>6] |= 1 << (i & 63)
}

// Build computes the per-word prefix counts. Call it after all Set calls
// and before any Rank call.
func (bv *BitVector) Build() {
	for i := 0; i < len(bv.bits)-1; i++ {
		bv.preSum[i+1] = bv.preSum[i] + int32(bits.OnesCount64(bv.bits[i]))
	}
}

// Rank returns the number of positions in [0, k) whose bit equals f.
func (bv *BitVector) Rank(k int32, f bool) int32 {
	m, s := bv.bits[k>>6], bv.preSum[k>>6]
	res := s + int32(bits.OnesCount64(m&((1<<(k&63))-1)))
	if f {
		return res
	}
	return k - res
}