結果
問題 | No.924 紲星 |
ユーザー |
|
提出日時 | 2024-05-03 17:08:09 |
言語 | Go (1.23.4) |
結果 |
AC
|
実行時間 | 652 ms / 4,000 ms |
コード長 | 18,256 bytes |
コンパイル時間 | 14,374 ms |
コンパイル使用メモリ | 217,944 KB |
実行使用メモリ | 68,200 KB |
最終ジャッジ日時 | 2024-11-24 20:04:17 |
合計ジャッジ時間 | 22,278 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 16 |
ソースコード
// 维护区间贡献的 Wavelet Matrix// !注意查询区间贡献时, 异或无效//// api:// !RangeCountAndSum(start, end, a, b, xor) - 区间 [start, end) 中值在 [a, b) 之间的数的个数和这些数的和.// !KthValueAndSum(start, end, k, xor) - 区间 [start, end) 中第 k 小的数和前 k 小的数的和.//// Kth(start, end, k, xor) - 区间 [start, end) 中第 k 小的数.// Median(start, end, upper, xor) - 区间 [start, end) 中的中位数.// Floor(start, end, x, xor) - 区间 [start, end) 中值小于等于 x 的最大值// Ceil(start, end, x, xor) - 区间 [start, end) 中值大于等于 x 的最小值//// SumAll(start, end) - 区间 [start, end) 中所有数的和.// SumRange(start, end, a, b, xor) - 区间 [start, end) 中值在 [a, b) 之间的数的和.// SumSlice(start, end, k1, k2, xor) - 区间 [start, end) 中第 k1 到第 k2 小的数的和.//// MaxRight(start, end, xor, check) - 返回使得 check(count,prefixSum) 为 true 的最大(count,prefixSum), 其中prefixSum为[0,val)内的数的和.package mainimport ("bufio""fmt""math/bits""os""sort")func main() {// demo()区间最短距离和()// abc281_e()// yuki919()// yuki1332()// yuki2065()}func demo() {nums := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}wm := NewWaveletMatrixWithSum(nums, nums, -1, true)fmt.Println(wm.RangeCountAndSum(0, 3, 3, 7, 0))fmt.Println(wm.KthValueAndSum(1, 5, 0, 0))nums = []int{1, 2, 3, 4, 5}wm = NewWaveletMatrixWithSum(nums, nums, -1, false)fmt.Println(wm.KthValueAndSum(0, 5, 2, 0))fmt.Println(wm.KthValueAndSum(0, 6, 5, 0))fmt.Println(wm.SumAll(0, 4))fmt.Println(wm.Floor(0, 5, 4, 0))fmt.Println(wm.Kth(0, 5, 1, 0))}func abc281_e() {in := bufio.NewReader(os.Stdin)out := bufio.NewWriter(os.Stdout)defer out.Flush()var n, m, k int32fmt.Fscan(in, &n, &m, &k)nums := make([]int, n)for i := range nums {fmt.Fscan(in, &nums[i])}wm := NewWaveletMatrixWithSum(nums, nums, -1, false)for i := int32(0); i < n-m+1; i++ {_, res := wm.KthValueAndSum(i, i+m, k, 0)fmt.Fprintln(out, res)}}// 子数组所有数到中位数的距离和(区间中位数距离和)// https://yukicoder.me/problems/no/924// n,q<=2e5// -1e9 <= nums[i] <= 1e9// 给定n个查询[l,r]// !求区间[l,r]中位数到区间[l,r]中每个数的距离之和// !也就求函数 f(x)= ∑|nums[i]-x| (l<=i<=right) 的最小值// !区间中位数func 区间最短距离和() {in := bufio.NewReader(os.Stdin)out := bufio.NewWriter(os.Stdout)defer out.Flush()var n, q intfmt.Fscan(in, &n, &q)OFFSET := int(1e9 + 10)nums := make([]int, n)for i := range nums {fmt.Fscan(in, &nums[i])nums[i] += OFFSET}wm := NewWaveletMatrixWithSum(nums, nums, -1, false)for i := 0; i < q; i++ {var start, end int32fmt.Fscan(in, &start, &end)start--n := end - startlowerCount := n / 2ceilCount := n - lowerCountmid, lowerSum := wm.KthValueAndSum(start, end, lowerCount, 0)allSum := wm.SumAll(start, end)ceilSum := allSum - lowerSumres := 0res += mid*int(lowerCount) - lowerSumres += ceilSum - mid*int(ceilCount)fmt.Fprintln(out, res)}}// No.919 You Are A Project Manager (区间中位数+前后缀分解)// https://yukicoder.me/problems/no/919// 给定一个数组nums,需要分为若干组.// 初始时选择分组大小k,每次分组需要从最左边或最右边取k个数分为一组,// 每个组的的得分为k*组内中位数(这里的中位数向下).// 求最大化总得分.//// n<=1e4,-1e9<=nums[i]<1e9func yuki919() {in := bufio.NewReader(os.Stdin)out := bufio.NewWriter(os.Stdout)defer out.Flush()var n int32fmt.Fscan(in, &n)nums := make([]int, n)for i := range nums {fmt.Fscan(in, &nums[i])}wm := NewWaveletMatrixWithSum(nums, nil, -1, true)median := func(start, end int32) int {return wm.Median(start, end, false, 0)}res := -INFfor size := int32(1); size <= n; size++ {count := n / sizeleftScore, rightScore := make([]int, count), make([]int, count)for i := int32(0); i < count; i++ {leftScore[i] = median(i*size, i*size+size)rightScore[i] = median(n-i*size-size, n-i*size)}leftPresum, rightPresum := make([]int, count+1), make([]int, count+1)for i := int32(0); i < count; i++ {leftPresum[i+1] = leftPresum[i] + leftScore[i]rightPresum[i+1] = rightPresum[i] + rightScore[i]}for i := int32(0); i < count; i++ {leftPresum[i+1] = max(leftPresum[i+1], leftPresum[i])rightPresum[i+1] = max(rightPresum[i+1], rightPresum[i])}for i := int32(0); i <= count; i++ {res = max(res, int(size)*(leftPresum[i]+rightPresum[count-i]))}}fmt.Fprintln(out, res)}// No.1332 Range Nearest Query// https://yukicoder.me/problems/no/1332// 区间最近值(无序数组区间前驱后继)查询// !对每个查询[left,right,x],输出左闭右开区间内的数到给定数x的最小距离// n<=3e5 q<=1e5 nums[i]<=1e9func yuki1332() {in := bufio.NewReader(os.Stdin)out := bufio.NewWriter(os.Stdout)defer out.Flush()var n intfmt.Fscan(in, &n)nums := make([]int, n)for i := 0; i < n; i++ {fmt.Fscan(in, &nums[i])}M := NewWaveletMatrixWithSum(nums, nil, -1, true)var q intfmt.Fscan(in, &q)for i := 0; i < q; i++ {var start, end int32var x intfmt.Fscan(in, &start, &end, &x)start--res := INFfloor := M.Floor(start, end, x, 0) // 小于等于x的最大值if floor != -INF {res = min(res, abs(floor-x))}ceiling := M.Ceil(start, end, x, 0) // 大于等于x的最小值if ceiling != INF {res = min(res, abs(ceiling-x))}fmt.Fprintln(out, res)}}// No.2065 Sum of Min (区间chmin之和)// https://yukicoder.me/problems/no/2065// 小于x数的和+(end-start-小于x数的个数)*xfunc yuki2065() {in := bufio.NewReader(os.Stdin)out := bufio.NewWriter(os.Stdout)defer out.Flush()var n, q int32fmt.Fscan(in, &n, &q)nums := make([]int, n)for i := range nums {fmt.Fscan(in, &nums[i])}M := NewWaveletMatrixWithSum(nums, nums, -1, true)for i := int32(0); i < q; i++ {var start, end int32var x intfmt.Fscan(in, &start, &end, &x)start--lessCount, _ := M.RangeCountAndSum(start, end, 0, x, 0) // 小于x的数的个数lessSum := M.SumSlice(start, end, 0, lessCount, 0) // 小于x的数的和res := lessSum + int((end-start-lessCount))*xfmt.Fprintln(out, res)}}const INF WmValue = 1e18type WmValue = inttype WmSum = intfunc (*WaveletMatrixWithSum) e() WmSum { return 0 }func (*WaveletMatrixWithSum) op(a, b WmSum) WmSum { return a + b }func (*WaveletMatrixWithSum) inv(a WmSum) WmSum { return -a }type WaveletMatrixWithSum struct {n, log int32setLog boolcompress booluseSum boolmid []int32bv []*BitVectorkey []WmValuepresum [][]WmSum}// nums: 数组元素.// sumData: 和数据,nil表示不需要和数据.// log: 如果需要支持异或查询则需要传入log,-1表示默认.// compress: 是否对nums进行离散化(值域较大(1e9)时可以离散化加速).func NewWaveletMatrixWithSum(nums []WmValue, sumData []WmSum, log int32, compress bool) *WaveletMatrixWithSum {wm := &WaveletMatrixWithSum{}wm.build(nums, sumData, log, compress)return wm}func (wm *WaveletMatrixWithSum) build(nums []WmValue, sumData []WmSum, log int32, compress bool) {numsCopy := append(nums[:0:0], nums...)sumDataCopy := append(sumData[:0:0], sumData...)wm.n = int32(len(numsCopy))wm.log = logwm.setLog = log != -1wm.compress = compresswm.useSum = len(sumData) > 0if wm.n == 0 {wm.log = 0wm.presum = [][]WmSum{{wm.e()}}return}if compress {if wm.setLog {panic("compress and log should not be set at the same time")}wm.key = make([]WmValue, 0, wm.n)order := wm._argSort(numsCopy)for _, i := range order {if len(wm.key) == 0 || wm.key[len(wm.key)-1] != numsCopy[i] {wm.key = append(wm.key, numsCopy[i])}numsCopy[i] = WmValue(len(wm.key) - 1)}wm.key = wm.key[:len(wm.key):len(wm.key)]}if wm.log == -1 {tmp := wm._maxs(numsCopy)if tmp < 1 {tmp = 1}wm.log = int32(bits.Len(uint(tmp)))}wm.mid = make([]int32, wm.log)wm.bv = make([]*BitVector, wm.log)for i := range wm.bv {wm.bv[i] = NewBitVector(wm.n)}if wm.useSum {wm.presum = make([][]WmSum, 1+wm.log)for i := range wm.presum {sums := make([]WmSum, wm.n+1)for j := range sums {sums[j] = wm.e()}wm.presum[i] = sums}}if len(sumDataCopy) == 0 {sumDataCopy = make([]WmSum, len(numsCopy))}A, S := numsCopy, sumDataCopyA0, A1 := make([]WmValue, wm.n), make([]WmValue, wm.n)S0, S1 := make([]WmSum, wm.n), make([]WmSum, wm.n)for d := wm.log - 1; d >= -1; d-- {p0, p1 := int32(0), int32(0)if wm.useSum {tmp := wm.presum[d+1]for i := int32(0); i < wm.n; i++ {tmp[i+1] = wm.op(tmp[i], S[i])}}if d == -1 {break}for i := int32(0); i < wm.n; i++ {f := (A[i] >> d & 1) == 1if !f {if wm.useSum {S0[p0] = S[i]}A0[p0] = A[i]p0++} else {if wm.useSum {S1[p1] = S[i]}wm.bv[d].Set(i)A1[p1] = A[i]p1++}}wm.mid[d] = p0wm.bv[d].Build()A, A0 = A0, AS, S0 = S0, Sfor i := int32(0); i < p1; i++ {A[p0+i] = A1[i]S[p0+i] = S1[i]}}}// 返回区间 [start, end) 中 值在 [a, b) 中的元素个数以及这些元素的和.func (wm *WaveletMatrixWithSum) RangeCountAndSum(start, end int32, a, b WmValue, xorValue WmValue) (int32, WmSum) {if xorValue != 0 {if !wm.setLog {panic("log should be set when xor is used")}}if start < 0 {start = 0}if end > wm.n {end = wm.n}if start >= end || a >= b {return 0, wm.e()}if wm.compress {a = wm._lowerBound(wm.key, a)b = wm._lowerBound(wm.key, b)}count, sum := int32(0), wm.e()var dfs func(d, l, r int32, lx, rx WmValue)dfs = func(d, l, r int32, lx, rx WmValue) {if rx <= a || b <= lx {return}if a <= lx && rx <= b {count += r - lif wm.useSum {sum = wm.op(sum, wm._get(d, l, r))}return}d--mx := (lx + rx) >> 1l0, r0 := wm.bv[d].Rank(l, false), wm.bv[d].Rank(r, false)l1, r1 := l+wm.mid[d]-l0, r+wm.mid[d]-r0if xorValue>>d&1 == 1 {l0, l1 = l1, l0r0, r1 = r1, r0}dfs(d, l0, r0, lx, mx)dfs(d, l1, r1, mx, rx)}dfs(wm.log, start, end, 0, 1<<wm.log)return count, sum}// 返回区间 [start, end) 中的 (第k小的元素, 前k个元素(不包括第k小的元素) 的 op 的结果).// 如果k >= end-start, 返回 (-1, 区间 op 的结果).func (wm *WaveletMatrixWithSum) KthValueAndSum(start, end, k int32, xorVal WmValue) (WmValue, WmSum) {if start < 0 {start = 0}if end > wm.n {end = wm.n}if start >= end {return -1, wm.e()}if k >= end-start {return -1, wm.SumAll(start, end)}if xorVal != 0 {if !wm.setLog {panic("log should be set when xor is used")}}sum, val := wm.e(), WmValue(0)for d := wm.log - 1; d >= 0; d-- {l0, r0 := wm.bv[d].Rank(start, false), wm.bv[d].Rank(end, false)l1, r1 := start+wm.mid[d]-l0, end+wm.mid[d]-r0if (xorVal>>d)&1 == 1 {l0, l1 = l1, l0r0, r1 = r1, r0}if k < r0-l0 {start, end = l0, r0} else {k -= r0 - l0val |= 1 << dstart, end = l1, r1if wm.useSum {sum = wm.op(sum, wm._get(d, l0, r0))}}}if wm.useSum {sum = wm.op(sum, wm._get(0, start, start+k))}if wm.compress {val = wm.key[val]}return val, sum}// [start, end)区间内第k(k>=0)小的元素.func (wm *WaveletMatrixWithSum) Kth(start, end, k int32, xorVal WmValue) WmValue {if k < 0 {k = 0}if n := end - start - 1; k > n {k = n}v, _ := wm.KthValueAndSum(start, end, k, xorVal)return v}// upper: 向上取中位数还是向下取中位数.func (wm *WaveletMatrixWithSum) Median(start, end int32, upper bool, xorVal WmValue) WmValue {n := end - startvar k int32if upper {k = n >> 1} else {k = (n - 1) >> 1}return wm.Kth(start, end, k, xorVal)}// [start, end) 中小于等于 x 的数中最大的数.//// 如果不存在则返回-INF.func (wm *WaveletMatrixWithSum) Floor(start, end int32, x WmValue, xor WmValue) WmValue {if xor != 0 {if !wm.setLog {panic("log should be set when xor is used")}}if start < 0 {start = 0}if end > wm.n {end = wm.n}if start >= end {return -INF}res := -INFx++if wm.compress {x = wm._lowerBound(wm.key, x)}var dfs func(d, l, r int32, lx, rx WmValue)dfs = func(d, l, r int32, lx, rx WmValue) {if rx-1 <= res || l == r || x <= lx {return}if d == 0 {res = max(res, lx)return}d--mx := (lx + rx) >> 1l0, r0 := wm.bv[d].Rank(l, false), wm.bv[d].Rank(r, false)l1, r1 := l+wm.mid[d]-l0, r+wm.mid[d]-r0if xor>>d&1 == 1 {l0, l1 = l1, l0r0, r1 = r1, r0}dfs(d, l1, r1, mx, rx)dfs(d, l0, r0, lx, mx)}dfs(wm.log, start, end, 0, 1<<wm.log)if wm.compress && res != -INF {res = wm.key[res]}return res}// [start, end) 中大于等于 x 的数中最小的数//// 如果不存在则返回INFfunc (wm *WaveletMatrixWithSum) Ceil(start, end int32, x WmValue, xor WmValue) int {if xor != 0 {if !wm.setLog {panic("log should be set when xor is used")}}if start < 0 {start = 0}if end > wm.n {end = wm.n}if start >= end {return INF}res := INFif wm.compress {x = wm._lowerBound(wm.key, x)}var dfs func(d, l, r int32, lx, rx WmValue)dfs = func(d, l, r int32, lx, rx WmValue) {if res <= lx || l == r || x <= lx {return}if d == 0 {res = min(res, rx)return}d--mx := (lx + rx) >> 1l0, r0 := wm.bv[d].Rank(l, false), wm.bv[d].Rank(r, false)l1, r1 := l+wm.mid[d]-l0, r+wm.mid[d]-r0if xor>>d&1 == 1 {l0, l1 = l1, l0r0, r1 = r1, r0}dfs(d, l0, r0, lx, mx)dfs(d, l1, r1, mx, rx)}dfs(wm.log, start, end, 0, 1<<wm.log)if wm.compress && res < INF {res = wm.key[res]}return res}// 返回区间 [start, end) 中 范围在 [a, b) 中的元素的和.func (wm *WaveletMatrixWithSum) SumRange(start, end int32, a, b WmValue, xorVal WmValue) WmSum {if !wm.useSum {panic("sum data must be provided")}if start < 0 {start = 0}if end > wm.n {end = wm.n}if start >= end || a >= b {return wm.e()}_, sum := wm.RangeCountAndSum(start, end, a, b, xorVal)return sum}// 返回区间 [start, end) 中 排名在 [k1, k2) 中的元素的和.func (wm *WaveletMatrixWithSum) SumSlice(start, end, k1, k2 int32, xorVal WmValue) WmSum {if !wm.useSum {panic("sum data must be provided")}if k1 < 0 {k1 = 0}if k2 > end-start {k2 = end - start}if start < 0 {start = 0}if end > wm.n {end = wm.n}if start >= end || k1 >= k2 {return wm.e()}_, sum1 := wm.KthValueAndSum(start, end, k1, xorVal)_, sum2 := wm.KthValueAndSum(start, end, k2, xorVal)return wm.op(sum2, wm.inv(sum1))}func (wm *WaveletMatrixWithSum) SumAll(start, end int32) WmSum {if start < 0 {start = 0}if end > wm.n {end = wm.n}if start >= end {return wm.e()}return wm._get(wm.log, start, end)}// 使得predicate(count, sum)为true的最大的(count, sum).func (wm *WaveletMatrixWithSum) MaxRight(predicate func(int32, WmSum) bool, start, end int32, xorVal WmValue) (int32, WmSum) {if xorVal != 0 {if !wm.setLog {panic("log should be set when xor is used")}}if start >= end {return 0, wm.e()}if s := wm._get(wm.log, start, end); predicate(end-start, s) {return end - start, s}count, sum := int32(0), wm.e()for d := wm.log - 1; d >= 0; d-- {l0, r0 := wm.bv[d].Rank(start, false), wm.bv[d].Rank(end, false)l1, r1 := start+wm.mid[d]-l0, end+wm.mid[d]-r0if xorVal>>d&1 == 1 {l0, l1 = l1, l0r0, r1 = r1, r0}if s := wm.op(sum, wm._get(d, l0, r0)); predicate(count+r0-l0, s) {count += r0 - l0sum = sstart, end = l1, r1} else {start, end = l0, r0}}k := wm._binarySearch(func(k int32) bool {return predicate(count+k, wm.op(sum, wm._get(0, start, start+k)))}, 0, end-start)count += ksum = wm.op(sum, wm._get(0, start, start+k))return count, sum}func (wm *WaveletMatrixWithSum) _get(d, l, r int32) WmSum {if wm.useSum {return wm.op(wm.presum[d][r], wm.inv(wm.presum[d][l]))}return wm.e()}func (wm *WaveletMatrixWithSum) _argSort(nums []WmValue) []int32 {order := make([]int32, len(nums))for i := range order {order[i] = int32(i)}sort.Slice(order, func(i, j int) bool { return nums[order[i]] < nums[order[j]] })return order}func (wm *WaveletMatrixWithSum) _maxs(nums []WmValue) WmValue {res := nums[0]for _, v := range nums {if v > res {res = v}}return res}func (wm *WaveletMatrixWithSum) _lowerBound(nums []WmValue, target WmValue) WmValue {left, right := int32(0), int32(len(nums)-1)for left <= right {mid := (left + right) >> 1if nums[mid] < target {left = mid + 1} else {right = mid - 1}}return WmValue(left)}func (wm *WaveletMatrixWithSum) _binarySearch(f func(int32) bool, ok, ng int32) int32 {for abs32(ok-ng) > 1 {x := (ok + ng) >> 1if f(x) {ok = x} else {ng = x}}return ok}type BitVector struct {bits []uint64preSum []int32}func NewBitVector(n int32) *BitVector {return &BitVector{bits: make([]uint64, n>>6+1), preSum: make([]int32, n>>6+1)}}func (bv *BitVector) Set(i int32) {bv.bits[i>>6] |= 1 << (i & 63)}func (bv *BitVector) Build() {for i := 0; i < len(bv.bits)-1; i++ {bv.preSum[i+1] = bv.preSum[i] + int32(bits.OnesCount64(bv.bits[i]))}}func (bv *BitVector) Rank(k int32, f bool) int32 {m, s := bv.bits[k>>6], bv.preSum[k>>6]res := s + int32(bits.OnesCount64(m&((1<<(k&63))-1)))if f {return res}return k - res}func abs(x int) int {if x < 0 {return -x}return x}func abs32(x int32) int32 {if x < 0 {return -x}return x}func min32(a, b int32) int32 {if a < b {return a}return b}func min(a, b int) int {if a < b {return a}return b}func max(a, b int) int {if a > b {return a}return b}