結果

問題 No.924 紲星
ユーザー 草苺奶昔草苺奶昔
提出日時 2024-04-14 20:42:50
言語 Go
(1.22.1)
結果
RE  
実行時間 -
コード長 14,811 bytes
コンパイル時間 9,796 ms
コンパイル使用メモリ 232,556 KB
実行使用メモリ 14,420 KB
最終ジャッジ日時 2024-10-03 23:24:54
合計ジャッジ時間 13,821 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 RE -
testcase_01 RE -
testcase_02 RE -
testcase_03 RE -
testcase_04 RE -
testcase_05 RE -
testcase_06 RE -
testcase_07 RE -
testcase_08 RE -
testcase_09 RE -
testcase_10 RE -
testcase_11 RE -
testcase_12 RE -
testcase_13 RE -
testcase_14 RE -
testcase_15 RE -
testcase_16 RE -
testcase_17 RE -
testcase_18 RE -
権限があれば一括ダウンロードができます

ソースコード

diff #

package main

import (
	"bufio"
	"fmt"
	"math/bits"
	"os"
	"sort"
)

// main is the program entry point; it runs the solver for yukicoder No.924.
func main() {
	区间最短距离和()
}

// demo is a manual smoke test: it builds a matrix over 1..10 without sum data
// and prints how many of the first three elements lie in [3, 7).
func demo() {
	wm := NewWaveletMatrixWithSum([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, nil, -1, false)
	inRange := wm.Count(0, 3, 3, 7, 0)
	fmt.Println(inRange)
}

// 区间最短距离和 solves yukicoder No.924.
//
// Input: n, q, then n integers (-1e9 <= nums[i] <= 1e9), then q queries l r
// (1-indexed, inclusive), with n, q <= 2e5. For every query it prints
// min_x Σ_{i=l..r} |nums[i]-x|. The minimum is attained at the range median,
// so the answer is (median*lowerCount - sum of the lower half) +
// (sum of the upper half - median*upperCount).
func 区间最短距离和() {
	in := bufio.NewReader(os.Stdin)
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()

	var n, q int
	fmt.Fscan(in, &n, &q)
	// The wavelet matrix decomposes values bit by bit and sizes its levels
	// from the maximum value, so shift everything to be non-negative. The
	// shift cancels in the answer, which is a sum of differences.
	const offset = int(1e9 + 10)
	nums := make([]int, n)
	for i := range nums {
		fmt.Fscan(in, &nums[i])
		nums[i] += offset
	}

	// The values themselves are the sum data. Passing nil here leaves the
	// matrix without prefix sums, and every sum query then panics.
	wm := NewWaveletMatrixWithSum(nums, nums, -1, false)
	for i := 0; i < q; i++ {
		var start, end int32
		fmt.Fscan(in, &start, &end)
		start-- // convert to the half-open range [start, end)

		size := end - start
		lowerCount := size / 2
		upperCount := size - lowerCount
		// median is the lowerCount-th smallest (0-based); lowerSum is the sum
		// of the lowerCount smallest values.
		median, lowerSum := wm.KthValueAndSum(start, end, lowerCount, 0)
		upperSum := wm.SumAll(start, end) - lowerSum

		res := median*int(lowerCount) - lowerSum + upperSum - median*int(upperCount)
		fmt.Fprintln(out, res)
	}
}

// WmValue is the element type stored in the wavelet matrix; WmSum is the type
// of the values aggregated by the sum queries.
type (
	WmValue = int
	WmSum   = int
)

// e returns the identity element of the sum monoid.
func (*WaveletMatrixWithSum) e() WmSum { return WmSum(0) }

// op combines two aggregated values (integer addition).
func (*WaveletMatrixWithSum) op(a, b WmSum) WmSum { return b + a }

// inv returns the inverse of a under op; it turns prefix sums into range sums.
func (*WaveletMatrixWithSum) inv(a WmSum) WmSum { return 0 - a }

// WaveletMatrixWithSum is a wavelet matrix over integer values. It can also
// keep, for every bit level, prefix sums of a parallel "sum data" array. It
// answers range k-th-smallest, range counting and "aggregate of the k
// smallest" queries with one rank step per bit level. Values are expected to
// be non-negative unless compress is used. Create it with
// NewWaveletMatrixWithSum.
type WaveletMatrixWithSum struct {
	n        int32        // number of stored elements
	log      int32        // number of bit levels
	mid      []int32      // mid[d]: number of elements whose bit d is 0
	bv       []*BitVector // bv[d]: bit d of every element, before level d's partition
	key      []WmValue    // sorted distinct input values when compress is set (rank -> value)
	setLog   bool         // log was supplied by the caller (required for xor queries)
	presum   [][]WmSum    // per-level prefix sums of the sum data; nil without sum data
	compress bool         // stored values are ranks into key
}

// NewWaveletMatrixWithSum builds a wavelet matrix over nums.
//
//   - sumData: values aggregated by the sum queries, aligned with nums. nil
//     means no sum data, and sum queries are then unavailable.
//   - log: number of bit levels. It must be given (not -1) to use xor
//     queries; -1 derives it from the largest stored value.
//   - compress: coordinate-compress nums first. This cannot be combined with
//     an explicit log.
func NewWaveletMatrixWithSum(nums []WmValue, sumData []WmSum, log int32, compress bool) *WaveletMatrixWithSum {
	var wm WaveletMatrixWithSum
	wm.build(nums, sumData, log, compress)
	return &wm
}

// build initialises wm from nums and the optional sumData (see
// NewWaveletMatrixWithSum for the parameters).
//
// For d = log-1 down to 0, level d stably partitions the *current* sequence
// by bit d of each value. Elements with bit 0 move to the front in their
// relative order, and the rest follow. bv[d] stores bit d of every element in
// the order before that partition, and mid[d] is the number of 0-bits. When
// sum data is present, presum[d+1] holds prefix sums of the sum data in the
// order before level d's partition. So presum[log] is the input order and
// presum[0] is the order after the last partition.
func (wm *WaveletMatrixWithSum) build(nums []WmValue, sumData []WmSum, log int32, compress bool) {
	// Work on copies: the level loop permutes these arrays.
	numsCopy := append(nums[:0:0], nums...)
	sumDataCopy := append(sumData[:0:0], sumData...)
	if len(sumData) == 0 {
		sumDataCopy = make([]WmSum, len(nums))
	}

	wm.n = int32(len(numsCopy))
	wm.log = log
	wm.compress = compress
	wm.setLog = log != -1
	if wm.n == 0 {
		wm.log = 0
		return
	}
	makeSum := len(sumData) > 0
	if compress {
		if wm.setLog {
			panic("compress and log should not be set at the same time")
		}
		// Replace each value by its rank among the distinct values; key maps
		// ranks back to values.
		wm.key = make([]WmValue, 0, wm.n)
		order := wm.argSort(numsCopy)
		for _, i := range order {
			if len(wm.key) == 0 || wm.key[len(wm.key)-1] != numsCopy[i] {
				wm.key = append(wm.key, numsCopy[i])
			}
			numsCopy[i] = WmValue(len(wm.key) - 1)
		}
	}
	if wm.log == -1 {
		// Size the levels from the values actually stored (the ranks when
		// compressing), not from the raw input.
		tmp := wm.maxs(numsCopy)
		if tmp < 1 {
			tmp = 1
		}
		wm.log = int32(bits.Len(uint(tmp)))
	}
	wm.mid = make([]int32, wm.log)
	wm.bv = make([]*BitVector, wm.log)
	for i := range wm.bv {
		wm.bv[i] = NewBitVector(wm.n)
	}
	if makeSum {
		wm.presum = make([][]WmSum, 1+wm.log)
		for i := range wm.presum {
			sums := make([]WmSum, wm.n+1)
			for j := range sums {
				sums[j] = wm.e()
			}
			wm.presum[i] = sums
		}
	}

	// A/S hold the current order. A0/S0 and A1/S1 collect the 0-side and the
	// 1-side of each split.
	A, S := numsCopy, sumDataCopy
	A0, A1 := make([]WmValue, wm.n), make([]WmValue, wm.n)
	S0, S1 := make([]WmSum, wm.n), make([]WmSum, wm.n)
	for d := wm.log - 1; d >= -1; d-- {
		if makeSum {
			// Prefix sums of the current order, not of the input order.
			tmp := wm.presum[d+1]
			for i := int32(0); i < wm.n; i++ {
				tmp[i+1] = wm.op(tmp[i], S[i])
			}
		}
		if d == -1 {
			break
		}
		p0, p1 := int32(0), int32(0)
		for i := int32(0); i < wm.n; i++ {
			if (A[i]>>d)&1 == 0 {
				A0[p0], S0[p0] = A[i], S[i]
				p0++
			} else {
				wm.bv[d].Set(i)
				A1[p1], S1[p1] = A[i], S[i]
				p1++
			}
		}
		wm.mid[d] = p0
		wm.bv[d].Build()
		// New order: all 0-side elements, then all 1-side elements, each stable.
		A, A0 = A0, A
		S, S0 = S0, S
		copy(A[p0:], A1[:p1])
		copy(S[p0:], S1[:p1])
	}
}

// Count returns how many positions i in [start, end) satisfy
// a <= nums[i]^xorVal < b.
func (wm *WaveletMatrixWithSum) Count(start, end int32, a, b WmValue, xorVal WmValue) int32 {
	belowB := wm._prefixCount(start, end, b, xorVal)
	belowA := wm._prefixCount(start, end, a, xorVal)
	return belowB - belowA
}

// Kth returns the k-th (0-based) smallest value of nums[i]^xorVal over
// i in [start, end). k must be in [0, end-start); no range check is done.
// A non-zero xorVal requires the matrix to have been built with an explicit
// log.
func (wm *WaveletMatrixWithSum) Kth(start, end, k int32, xorVal WmValue) WmValue {
	if xorVal != 0 && !wm.setLog {
		panic("log should be set")
	}
	count := int32(0)
	res := WmValue(0)
	for d := wm.log - 1; d >= 0; d-- {
		f := (xorVal>>d)&1 == 1
		l0, r0 := wm.bv[d].Rank(start, false), wm.bv[d].Rank(end, false)
		// c: number of elements whose (value^xorVal) bit d is 0.
		var c int32
		if f {
			c = (end - start) - (r0 - l0)
		} else {
			c = r0 - l0
		}
		if count+c > k {
			// Answer's bit d is 0: descend into the side whose (value^xorVal)
			// bit is 0.
			if !f {
				start, end = l0, r0
			} else {
				start += wm.mid[d] - l0
				end += wm.mid[d] - r0
			}
		} else {
			// Answer's bit d is 1: skip the c smaller elements and descend
			// into the other side.
			count += c
			res |= 1 << d
			if !f {
				start += wm.mid[d] - l0
				end += wm.mid[d] - r0
			} else {
				start, end = l0, r0
			}
		}
	}
	if wm.compress {
		res = wm.key[res]
	}
	return res
}

// KthValueAndSum returns two values for the range [start, end):
//   - the k-th (0-based) smallest value of nums[i]^xorVal, and
//   - the op-aggregate of the sum data of the k smallest elements (the k-th
//     element itself is not included).
//
// It returns (-1, e()) for an empty range. If k >= end-start it returns
// (-1, aggregate over the whole range).
func (wm *WaveletMatrixWithSum) KthValueAndSum(start, end, k int32, xorVal WmValue) (WmValue, WmSum) {
	switch {
	case start >= end:
		return -1, wm.e()
	case k >= end-start:
		return -1, wm.SumAll(start, end)
	}
	if xorVal != 0 && !wm.setLog {
		panic("log should be set when xor is used")
	}
	taken := int32(0)
	acc := wm.e()
	val := WmValue(0)
	for d := wm.log - 1; d >= 0; d-- {
		flip := (xorVal>>d)&1 == 1
		zeroL, zeroR := wm.bv[d].Rank(start, false), wm.bv[d].Rank(end, false)
		oneL, oneR := start+wm.mid[d]-zeroL, end+wm.mid[d]-zeroR
		// "lo" is the side whose (value^xorVal) bit d is 0; "hi" is the other.
		loL, loR, hiL, hiR := zeroL, zeroR, oneL, oneR
		if flip {
			loL, loR, hiL, hiR = oneL, oneR, zeroL, zeroR
		}
		cnt := loR - loL
		if taken+cnt > k {
			start, end = loL, loR
			continue
		}
		// The whole lo side belongs to the k smallest: absorb it and go right.
		taken += cnt
		acc = wm.op(acc, wm._get(d, loL, loR))
		val |= 1 << d
		start, end = hiL, hiR
	}
	// The remaining range holds copies of the answer; take the missing ones.
	acc = wm.op(acc, wm._get(0, start, start+k-taken))
	if wm.compress {
		val = wm.key[val]
	}
	return val, acc
}

// Median returns the median of nums[i]^xorVal over [start, end). For an
// even-length range, upper selects the larger middle element (index n/2);
// otherwise the smaller one (index (n-1)/2) is returned.
func (wm *WaveletMatrixWithSum) Median(start, end int32, upper bool, xorVal WmValue) WmValue {
	size := end - start
	k := (size - 1) / 2
	if upper {
		k = size / 2
	}
	return wm.Kth(start, end, k, xorVal)
}

// Sum returns the op-aggregate of the sum data of the elements in
// [start, end) whose 0-based rank (ascending by value^xorVal) lies in
// [k1, k2). It returns e() when k1 >= k2.
func (wm *WaveletMatrixWithSum) Sum(start, end, k1, k2 int32, xorVal WmValue) WmSum {
	if k2 <= k1 {
		return wm.e()
	}
	upToK2 := wm._prefixSum(start, end, k2, xorVal)
	upToK1 := wm._prefixSum(start, end, k1, xorVal)
	return wm.op(upToK2, wm.inv(upToK1))
}

// SumAll returns the op-aggregate of the sum data at positions [start, end)
// of the input order (prefix-sum level wm.log).
func (wm *WaveletMatrixWithSum) SumAll(start, end int32) WmSum {
	inputLevel := wm.log
	return wm._get(inputLevel, start, end)
}

// MaxRight returns the largest (count, sum) such that predicate(count, sum)
// holds. Here count is a number of the smallest elements (by value^xorVal)
// of [start, end), and sum is the op-aggregate of their sum data.
// NOTE(review): the descent and final binary search assume predicate is
// monotone (true up to some count, false after) — confirm at call sites.
func (wm *WaveletMatrixWithSum) MaxRight(predicate func(int32, WmSum) bool, start, end int32, xorVal WmValue) (int32, WmSum) {
	if xorVal != 0 && !wm.setLog {
		panic("log should be set when xor is used")
	}
	if start == end {
		return end - start, wm.e()
	}
	if whole := wm._get(wm.log, start, end); predicate(end-start, whole) {
		return end - start, whole
	}
	taken := int32(0)
	acc := wm.e()
	for d := wm.log - 1; d >= 0; d-- {
		flip := (xorVal>>d)&1 == 1
		zeroL, zeroR := wm.bv[d].Rank(start, false), wm.bv[d].Rank(end, false)
		oneL, oneR := start+wm.mid[d]-zeroL, end+wm.mid[d]-zeroR
		// "lo" is the side whose (value^xorVal) bit d is 0; "hi" is the other.
		loL, loR, hiL, hiR := zeroL, zeroR, oneL, oneR
		if flip {
			loL, loR, hiL, hiR = oneL, oneR, zeroL, zeroR
		}
		cnt := loR - loL
		if cand := wm.op(acc, wm._get(d, loL, loR)); predicate(taken+cnt, cand) {
			taken += cnt
			acc = cand
			start, end = hiL, hiR
		} else {
			start, end = loL, loR
		}
	}
	// All remaining elements share one value; find how many of them still fit.
	k := wm.binarySearch(func(j int32) bool {
		return predicate(taken+j, wm.op(acc, wm._get(0, start, start+j)))
	}, 0, end-start)
	return taken + k, wm.op(acc, wm._get(0, start, start+k))
}

// CountSegments returns Count(seg[0], seg[1], a, b, xorVal) summed over every
// half-open range seg in segments.
func (wm *WaveletMatrixWithSum) CountSegments(segments [][2]int32, a, b WmValue, xorVal WmValue) int32 {
	total := int32(0)
	for i := range segments {
		total += wm.Count(segments[i][0], segments[i][1], a, b, xorVal)
	}
	return total
}

// KthSegments returns the k-th (0-based) smallest value of nums[i]^xorVal
// over the union of the half-open index ranges in segments. k must be smaller
// than the total length of the ranges. segments itself is not modified.
func (wm *WaveletMatrixWithSum) KthSegments(segments [][2]int32, k int32, xorVal WmValue) WmValue {
	if xorVal != 0 && !wm.setLog {
		panic("log should be set when xor is used")
	}
	// The ranges are narrowed level by level, so work on a private copy.
	// Ranging by value over [][2]int32 would only update a copy of each
	// element and lose the narrowing.
	segs := append([][2]int32(nil), segments...)
	count := int32(0)
	res := WmValue(0)
	for d := wm.log - 1; d >= 0; d-- {
		f := (xorVal>>d)&1 == 1
		// c: number of elements across all ranges whose (value^xorVal) bit d is 0.
		c := int32(0)
		for _, seg := range segs {
			l0, r0 := wm.bv[d].Rank(seg[0], false), wm.bv[d].Rank(seg[1], false)
			if f {
				c += (seg[1] - seg[0]) - (r0 - l0)
			} else {
				c += r0 - l0
			}
		}
		goLow := count+c > k
		if !goLow {
			count += c
			res |= 1 << d
		}
		for i := range segs {
			seg := &segs[i]
			l0, r0 := wm.bv[d].Rank(seg[0], false), wm.bv[d].Rank(seg[1], false)
			// The stored 0-side is the target when going low without an xor
			// flip, or when going high with one.
			if goLow != f {
				seg[0], seg[1] = l0, r0
			} else {
				seg[0] += wm.mid[d] - l0
				seg[1] += wm.mid[d] - r0
			}
		}
	}
	if wm.compress {
		res = wm.key[res]
	}
	return res
}

// KthValueAndSumSegments is KthValueAndSum over the union of the half-open
// index ranges in segments. It returns the k-th (0-based) smallest value of
// nums[i]^xorVal, and the op-aggregate of the sum data of the k smallest
// elements (not including the k-th). If k is at least the total length, it
// returns (-1, aggregate over all ranges). segments itself is not modified.
func (wm *WaveletMatrixWithSum) KthValueAndSumSegments(segments [][2]int32, k int32, xorVal WmValue) (WmValue, WmSum) {
	if xorVal != 0 && !wm.setLog {
		panic("log should be set when xor is used")
	}
	totalLen := int32(0)
	for _, seg := range segments {
		totalLen += seg[1] - seg[0]
	}
	if k >= totalLen {
		return -1, wm.SumAllSegments(segments)
	}
	// Narrow private copies of the ranges. Ranging by value over [][2]int32
	// would only update a copy of each element and lose the narrowing.
	segs := append([][2]int32(nil), segments...)
	count := int32(0)
	sum := wm.e()
	res := WmValue(0)
	for d := wm.log - 1; d >= 0; d-- {
		f := (xorVal>>d)&1 == 1
		// c: number of elements across all ranges whose (value^xorVal) bit d is 0.
		c := int32(0)
		for _, seg := range segs {
			l0, r0 := wm.bv[d].Rank(seg[0], false), wm.bv[d].Rank(seg[1], false)
			if f {
				c += (seg[1] - seg[0]) - (r0 - l0)
			} else {
				c += r0 - l0
			}
		}
		goLow := count+c > k
		if !goLow {
			count += c
			res |= 1 << d
		}
		for i := range segs {
			seg := &segs[i]
			l0, r0 := wm.bv[d].Rank(seg[0], false), wm.bv[d].Rank(seg[1], false)
			oneL, oneR := seg[0]+wm.mid[d]-l0, seg[1]+wm.mid[d]-r0
			if !goLow {
				// Every element on this range's low side is among the k smallest.
				if f {
					sum = wm.op(sum, wm._get(d, oneL, oneR))
				} else {
					sum = wm.op(sum, wm._get(d, l0, r0))
				}
			}
			// The stored 0-side is the target when going low without an xor
			// flip, or when going high with one.
			if goLow != f {
				seg[0], seg[1] = l0, r0
			} else {
				seg[0], seg[1] = oneL, oneR
			}
		}
	}
	// Every remaining element equals the answer; take the missing k-count of
	// them from the ranges in order.
	for _, seg := range segs {
		take := min32(seg[1]-seg[0], k-count)
		sum = wm.op(sum, wm._get(0, seg[0], seg[0]+take))
		count += take
	}
	if wm.compress {
		res = wm.key[res]
	}
	return res, sum
}

// MedianSegments returns the median of nums[i]^xorVal over the union of the
// ranges in segments. For an even total length, upper picks index n/2;
// otherwise index (n-1)/2 is used.
func (wm *WaveletMatrixWithSum) MedianSegments(segments [][2]int32, upper bool, xorVal WmValue) WmValue {
	size := int32(0)
	for i := range segments {
		size += segments[i][1] - segments[i][0]
	}
	k := (size - 1) / 2
	if upper {
		k = size / 2
	}
	return wm.KthSegments(segments, k, xorVal)
}

// SumAllSegments returns the op-aggregate of the sum data over every range in
// segments, using the input order.
func (wm *WaveletMatrixWithSum) SumAllSegments(segments [][2]int32) WmSum {
	acc := wm.e()
	for i := range segments {
		lo, hi := segments[i][0], segments[i][1]
		acc = wm.op(acc, wm._get(wm.log, lo, hi))
	}
	return acc
}

// _prefixCount returns how many elements in [start, end) have
// value^xor < x. With compress, x is first mapped to its rank among the
// stored keys.
func (wm *WaveletMatrixWithSum) _prefixCount(start, end int32, x WmValue, xor WmValue) int32 {
	if xor != 0 && !wm.setLog {
		panic("log should be set when xor is used")
	}
	if wm.compress {
		x = wm.lowerBound(wm.key, x)
	}
	switch {
	case x == 0:
		return 0
	case x >= 1<<wm.log:
		return end - start
	}
	below := int32(0)
	for d := wm.log - 1; d >= 0; d-- {
		flip := (xor>>d)&1 == 1
		zeroL, zeroR := wm.bv[d].Rank(start, false), wm.bv[d].Rank(end, false)
		oneL, oneR := start+wm.mid[d]-zeroL, end+wm.mid[d]-zeroR
		// "lo" is the side whose (value^xor) bit d is 0; "hi" is the other.
		loL, loR, hiL, hiR := zeroL, zeroR, oneL, oneR
		if flip {
			loL, loR, hiL, hiR = oneL, oneR, zeroL, zeroR
		}
		if (x>>d)&1 == 1 {
			// Every lo element is below x; continue matching x on the hi side.
			below += loR - loL
			start, end = hiL, hiR
		} else {
			start, end = loL, loR
		}
	}
	return below
}

// _prefixSum returns the op-aggregate of the sum data of the k smallest
// elements (by value^xor) in [start, end).
func (wm *WaveletMatrixWithSum) _prefixSum(start, end, k int32, xor WmValue) WmSum {
	_, total := wm.KthValueAndSum(start, end, k, xor)
	return total
}

// _prefixSumSegments returns the op-aggregate of the sum data of the k
// smallest elements (by value^xor) over the union of segments.
func (wm *WaveletMatrixWithSum) _prefixSumSegments(segments [][2]int32, k int32, xor WmValue) WmSum {
	_, total := wm.KthValueAndSumSegments(segments, k, xor)
	return total
}

// _get returns the op-aggregate of the sum data at positions [l, r) of the
// element order stored in prefix-sum level d (d == wm.log is the input order).
// It requires sum data: a matrix built with nil sumData has no prefix sums.
// In that case it panics with an explicit message rather than an opaque
// index-out-of-range error.
func (wm *WaveletMatrixWithSum) _get(d, l, r int32) WmSum {
	if wm.presum == nil {
		panic("sum query on a WaveletMatrixWithSum built without sumData")
	}
	return wm.op(wm.inv(wm.presum[d][l]), wm.presum[d][r])
}

// argSort returns the indices of nums ordered by ascending value.
func (wm *WaveletMatrixWithSum) argSort(nums []WmValue) []int32 {
	idx := make([]int32, len(nums))
	for i := range idx {
		idx[i] = int32(i)
	}
	less := func(a, b int) bool { return nums[idx[a]] < nums[idx[b]] }
	sort.Slice(idx, less)
	return idx
}

// maxs returns the largest element of nums. nums must be non-empty.
func (wm *WaveletMatrixWithSum) maxs(nums []WmValue) WmValue {
	best := nums[0]
	for i := 1; i < len(nums); i++ {
		if nums[i] > best {
			best = nums[i]
		}
	}
	return best
}

// lowerBound returns the first index i with nums[i] >= target, or len(nums)
// when none exists. nums must be sorted ascending.
func (wm *WaveletMatrixWithSum) lowerBound(nums []WmValue, target WmValue) WmValue {
	lo, hi := int32(0), int32(len(nums))-1
	for lo <= hi {
		m := lo + (hi-lo)/2
		if nums[m] >= target {
			hi = m - 1
		} else {
			lo = m + 1
		}
	}
	return WmValue(lo)
}

// binarySearch narrows the bracket (ok, ng) until the two ends are adjacent
// and returns ok. It assumes f(ok) holds and that f flips from true to false
// somewhere between ok and ng.
func (wm *WaveletMatrixWithSum) binarySearch(f func(int32) bool, ok, ng int32) int32 {
	for abs32(ng-ok) > 1 {
		probe := (ok + ng) >> 1
		if !f(probe) {
			ng = probe
			continue
		}
		ok = probe
	}
	return ok
}

// abs32 returns |x|. Like plain negation, it wraps for math.MinInt32.
func abs32(x int32) int32 {
	if x >= 0 {
		return x
	}
	return -x
}

// min32 returns the smaller of a and b.
func min32(a, b int32) int32 {
	if b < a {
		return b
	}
	return a
}

// BitVector is a static bit array with constant-time rank queries.
// Set the 1-bits with Set, then call Build once before using Rank.
type BitVector struct {
	bits   []uint64 // bit i is bit (i & 63) of bits[i >> 6]
	preSum []int32  // preSum[w] = number of 1-bits in bits[0:w]
}

// NewBitVector returns an all-zero bit vector holding n bits. One word more
// than strictly needed is allocated, so Rank(k) is valid for every k in
// [0, n]. This includes k == n when n is a multiple of 64.
func NewBitVector(n int32) *BitVector {
	words := n>>6 + 1
	return &BitVector{bits: make([]uint64, words), preSum: make([]int32, words)}
}

// Set sets bit i to 1. It must be called before Build.
func (bv *BitVector) Set(i int32) {
	bv.bits[i>>6] |= 1 << (i & 63)
}

// Build computes the per-word prefix popcounts that Rank relies on.
func (bv *BitVector) Build() {
	for i := 0; i+1 < len(bv.bits); i++ {
		bv.preSum[i+1] = bv.preSum[i] + int32(bits.OnesCount64(bv.bits[i]))
	}
}

// Rank returns how many of the bits at positions [0, k) equal f.
func (bv *BitVector) Rank(k int32, f bool) int32 {
	m, s := bv.bits[k>>6], bv.preSum[k>>6]
	ones := s + int32(bits.OnesCount64(m&((1<<(k&63))-1)))
	if f {
		return ones
	}
	return k - ones
}
0