結果

問題 No.1332 Range Nearest Query
ユーザー 草苺奶昔草苺奶昔
提出日時 2024-04-24 20:12:37
言語 Go
(1.22.1)
結果
WA  
実行時間 -
コード長 12,458 bytes
コンパイル時間 13,642 ms
コンパイル使用メモリ 220,876 KB
実行使用メモリ 17,280 KB
最終ジャッジ日時 2024-11-07 05:06:40
合計ジャッジ時間 36,885 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1 ms
5,248 KB
testcase_01 AC 1 ms
5,248 KB
testcase_02 WA -
testcase_03 WA -
testcase_04 WA -
testcase_05 WA -
testcase_06 AC 406 ms
16,128 KB
testcase_07 AC 407 ms
16,256 KB
testcase_08 AC 410 ms
16,128 KB
testcase_09 AC 403 ms
17,152 KB
testcase_10 AC 403 ms
16,128 KB
testcase_11 AC 403 ms
16,640 KB
testcase_12 AC 406 ms
16,128 KB
testcase_13 AC 408 ms
17,280 KB
testcase_14 AC 404 ms
16,128 KB
testcase_15 AC 403 ms
16,128 KB
testcase_16 WA -
testcase_17 WA -
testcase_18 WA -
testcase_19 WA -
testcase_20 WA -
testcase_21 WA -
testcase_22 WA -
testcase_23 WA -
testcase_24 WA -
testcase_25 WA -
testcase_26 AC 359 ms
16,000 KB
testcase_27 AC 370 ms
16,000 KB
testcase_28 WA -
testcase_29 WA -
testcase_30 WA -
testcase_31 AC 103 ms
5,248 KB
testcase_32 WA -
testcase_33 WA -
testcase_34 WA -
testcase_35 WA -
testcase_36 WA -
testcase_37 WA -
testcase_38 WA -
testcase_39 WA -
testcase_40 WA -
testcase_41 WA -
testcase_42 WA -
testcase_43 WA -
testcase_44 WA -
testcase_45 WA -
testcase_46 WA -
testcase_47 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

package main

import (
	"bufio"
	"fmt"
	"math"
	"math/bits"
	"os"
	"sort"
)

// main is the program entry point. It runs the range predecessor/successor
// driver (区间前驱后继); the commented-out calls below name other drivers that
// are currently disabled.
func main() {

	// test()
	// testTime()
	// yosupo()

	区间前驱后继()
}

// yosupo is a stdin/stdout driver for
// https://judge.yosupo.jp/problem/range_kth_smallest
//
// It reads n and q, then n values, then q queries "start end k". The values
// are coordinate-compressed before being loaded into the wavelet matrix, so
// every answer (a rank) is mapped back to the original value before printing.
func yosupo() {
	reader := bufio.NewReader(os.Stdin)
	writer := bufio.NewWriter(os.Stdout)
	defer writer.Flush()

	var n, q int
	fmt.Fscan(reader, &n, &q)
	values := make([]int, n)
	for i := range values {
		fmt.Fscan(reader, &values[i])
	}

	ranks, sorted := DiscretizeFast(values)
	rankAt := func(i int32) int { return int(ranks[i]) }
	wm := NewWaveletMatrixStatic(int32(len(values)), rankAt, len(sorted))

	for ; q > 0; q-- {
		var left, right, k int32
		fmt.Fscan(reader, &left, &right, &k)
		fmt.Fprintln(writer, sorted[wm.KthSmallest(left, right, k)])
	}
}

// 区间前驱后继 ("range predecessor/successor") is a stdin/stdout driver that
// answers nearest-value queries on a static array.
//
// Input: n, then n values, then q, then q queries "start end x". start is
// decremented after reading, so each query covers the half-open index range
// [start, end) of the 0-based array. For every query the minimum |value - x|
// over that range is printed on its own line (math.MaxInt if neither a
// predecessor nor a successor is found).
//
// The values are coordinate-compressed, so the wavelet matrix stores ranks
// into origin rather than the values themselves.
func 区间前驱后继() {
	in := bufio.NewReader(os.Stdin)
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()

	var n int32
	fmt.Fscan(in, &n)
	nums := make([]int, n)
	for i := int32(0); i < n; i++ {
		fmt.Fscan(in, &nums[i])
	}

	newNums, origin := DiscretizeFast(nums)

	wm := NewWaveletMatrixStatic(int32(len(newNums)), func(i int32) int { return int(newNums[i]) }, len(origin))
	var q int32
	fmt.Fscan(in, &q)
	for i := int32(0); i < q; i++ {
		var start, end int32
		var x int
		fmt.Fscan(in, &start, &end, &x)
		start--

		// newX is the rank of the first distinct value >= x (len(origin) if
		// there is none). x itself need not occur in origin, so rank newX
		// must NOT be treated as "equal to x": ranks < newX are exactly the
		// values < x, and ranks >= newX are exactly the values >= x.
		newX := int(BisectLeft(origin, x))

		res := math.MaxInt
		// Closest value strictly below x. Using Floor(newX) here would
		// return origin[newX] (a value > x when x is absent) and skip the
		// real predecessor.
		if lower, ok := wm.Lower(start, end, newX); ok {
			res = min(res, x-origin[lower])
		}
		// Closest value at or above x.
		if ceil, ok := wm.Ceil(start, end, newX); ok {
			res = min(res, origin[ceil]-x)
		}
		fmt.Fprintln(out, res)
	}
}

// demo builds a wavelet matrix over a small fixed sample and prints the
// result of one call to each of the main query methods, in a fixed order.
func demo() {
	sample := []int{3, 1, 4, 1, 5, 9, 2, 6}
	valueAt := func(i int32) int { return sample[i] }
	wm := NewWaveletMatrixStatic(int32(len(sample)), valueAt, maxs(sample)+1)

	// Whole-array queries use the half-open range [0, 8).
	const lo, hi = 0, 8

	fmt.Println(wm.PrefixCount(10, 2))
	fmt.Println(wm.Kth(0, 2))
	fmt.Println(wm.KthSmallest(1, 4, 2))
	fmt.Println(wm.RangeFreq(lo, hi, 1, 4))
	fmt.Println(wm.RangeCount(lo, hi, 1))
	fmt.Println(wm.RangeCount(lo, hi, 2))
	fmt.Println(wm.Floor(lo, hi, 3))
	fmt.Println(wm.Ceil(lo, hi, 3))
	fmt.Println(wm.Lower(lo, hi, 3))
	fmt.Println(wm.Higher(lo, hi, 3))
}

// DiscretizeFast coordinate-compresses nums.
//
// It returns newNums, where newNums[i] is the rank of nums[i] among the
// distinct values of nums, and origin, the distinct values in ascending
// order. The two satisfy origin[newNums[i]] == nums[i]. origin is returned
// with its capacity trimmed to its length.
func DiscretizeFast(nums []int) (newNums []int32, origin []int) {
	n := len(nums)
	byValue := argSort(int32(n), func(a, b int32) bool { return nums[a] < nums[b] })

	ranks := make([]int32, n)
	distinct := make([]int, 0, n)
	for _, idx := range byValue {
		v := nums[idx]
		// Indices arrive in ascending value order, so a value is new exactly
		// when it differs from the last distinct value recorded.
		if last := len(distinct) - 1; last < 0 || distinct[last] != v {
			distinct = append(distinct, v)
		}
		ranks[idx] = int32(len(distinct) - 1)
	}
	return ranks, distinct[:len(distinct):len(distinct)]
}

// BisectLeft returns the index of the first element of the ascending slice
// nums that is >= target, or len(nums) when every element is smaller.
func BisectLeft(nums []int, target int) int32 {
	// Invariant: everything before lo is < target, everything from hi on is
	// >= target.
	lo, hi := int32(0), int32(len(nums))
	for lo < hi {
		mid := lo + (hi-lo)>>1
		if nums[mid] >= target {
			hi = mid
		} else {
			lo = mid + 1
		}
	}
	return lo
}

// argSort returns the indices 0..n-1 ordered so that less(order[a], order[b])
// holds for a before b. The sort is sort.Slice, so ties are not guaranteed to
// keep their original relative order.
func argSort(n int32, less func(i, j int32) bool) []int32 {
	indices := make([]int32, n)
	for i := int32(0); i < n; i++ {
		indices[i] = i
	}
	byLess := func(a, b int) bool {
		return less(indices[a], indices[b])
	}
	sort.Slice(indices, byLess)
	return indices
}

// maxs returns the largest element of nums. nums must be non-empty; an empty
// slice causes an index-out-of-range panic.
func maxs(nums []int) int {
	best := nums[0]
	for _, v := range nums[1:] {
		if best < v {
			best = v
		}
	}
	return best
}

// maxs32 returns the largest element of nums. nums must be non-empty; an
// empty slice causes an index-out-of-range panic.
func maxs32(nums []int32) int32 {
	best := nums[0]
	for _, v := range nums[1:] {
		if best < v {
			best = v
		}
	}
	return best
}

// WaveletMatrixStatic is an immutable wavelet matrix over integers.
// It maintains values in the range [0, maxValue].
type WaveletMatrixStatic struct {
	size     int32        // number of elements (the n given to the constructor)
	maxValue int          // upper bound supplied by the caller, raised to at least 1
	bitLen   int32        // number of bit levels: bits.Len(uint(maxValue))
	mid      []int32      // mid[bit]: how many elements have a 0 at that bit level
	bv       []*bitVector // bv[bit]: bit vector recording that bit for every element
}

// topKPair pairs a value with an occurrence count.
// NOTE(review): nothing in the visible part of this file references it;
// presumably it is kept for a top-k query that is not included here.
type topKPair = struct {
	value int
	count int32
}

// NewWaveletMatrixStatic builds a wavelet matrix over the n values
// f(0), ..., f(n-1).
//
// maxValue determines how many bit levels are kept (bits.Len(maxValue));
// values passed in that are <= 0 are treated as 1. When n <= 0 the per-level
// bit vectors are left unbuilt (nil).
func NewWaveletMatrixStatic(n int32, f func(int32) int, maxValue int) *WaveletMatrixStatic {
	if maxValue < 1 {
		maxValue = 1
	}
	levels := int32(bits.Len(uint(maxValue)))
	wm := &WaveletMatrixStatic{
		size:     n,
		maxValue: maxValue,
		bitLen:   levels,
		mid:      make([]int32, levels),
		bv:       make([]*bitVector, levels),
	}
	if n <= 0 {
		return wm
	}
	wm._build(n, f)
	return wm
}

// PrefixCount returns how many of the first end elements equal x.
// end is clamped to the matrix size; a non-positive end yields 0.
// Only the low bitLen bits of x are examined.
func (wm *WaveletMatrixStatic) PrefixCount(end int32, x int) int32 {
	hi := end
	if hi > wm.size {
		hi = wm.size
	}
	if hi <= 0 {
		return 0
	}
	lo := int32(0)
	// Follow x's bits from the top level down, carrying the range [lo, hi)
	// into the zero or one partition of each level.
	for bit := wm.bitLen - 1; bit >= 0; bit-- {
		level := wm.bv[bit]
		if x>>bit&1 == 0 {
			lo, hi = level.Count0(lo), level.Count0(hi)
			continue
		}
		offset := wm.mid[bit]
		lo, hi = level.Count1(lo)+offset, level.Count1(hi)+offset
	}
	return hi - lo
}

// RangeCount returns the number of elements equal to x in [start, end).
// The range is clamped to [0, size); an empty range yields 0.
func (wm *WaveletMatrixStatic) RangeCount(start, end int32, x int) int32 {
	lo, hi := start, end
	if lo < 0 {
		lo = 0
	}
	if hi > wm.size {
		hi = wm.size
	}
	if hi <= lo {
		return 0
	}
	equal, _, _ := wm.CountAll(lo, hi, x)
	return equal
}

// RangeFreq returns the number of elements v in [start, end) with
// floor <= v < higher. The index range is clamped to [0, size); an empty
// index range or an empty value range (floor >= higher) yields 0.
func (wm *WaveletMatrixStatic) RangeFreq(start, end int32, floor, higher int) int32 {
	if higher <= floor {
		return 0
	}
	lo, hi := start, end
	if lo < 0 {
		lo = 0
	}
	if hi > wm.size {
		hi = wm.size
	}
	if hi <= lo {
		return 0
	}
	belowFloor := wm.CountLess(lo, hi, floor)
	belowHigher := wm.CountLess(lo, hi, higher)
	return belowHigher - belowFloor
}

// Kth returns the position of the k-th occurrence of x.
// k is 0-based: k is added directly to the start of x's block, so k == 0
// selects the first occurrence.
// NOTE(review): there is no check that x occurs more than k times; the result
// in that case has not been verified.
func (wm *WaveletMatrixStatic) Kth(k int32, x int) int32 {
	// Descend from the top bit to find where x's block starts in the ordering
	// below the lowest level. Count0(wm.size) is the total number of zero bits
	// on a level, i.e. the offset at which that level's one-partition begins.
	s := int32(0)
	for bit := wm.bitLen - 1; bit >= 0; bit-- {
		if x>>bit&1 == 1 {
			s = wm.bv[bit].Count0(wm.size) + wm.bv[bit].Count1(s)
		} else {
			s = wm.bv[bit].Count0(s)
		}
	}
	// Step to the k-th entry inside the block.
	s += k
	// Ascend from the lowest bit, undoing each level's partition with
	// Kth1/Kth0 to map the position back through the levels.
	for bit := int32(0); bit < wm.bitLen; bit++ {
		if x>>bit&1 == 1 {
			s = wm.bv[bit].Kth1(s - wm.bv[bit].Count0(wm.size))
		} else {
			s = wm.bv[bit].Kth0(s)
		}
	}
	return s
}

// KthSmallest returns the k-th smallest value (0-based) among the elements in
// [start, end), or -1 when k is negative or not less than end-start.
func (wm *WaveletMatrixStatic) KthSmallest(start, end int32, k int32) int {
	if k < 0 || end-start <= k {
		return -1
	}
	value := 0
	for bit := wm.bitLen - 1; bit >= 0; bit-- {
		level := wm.bv[bit]
		zerosBefore, zerosUpTo := level.Count0(start), level.Count0(end)
		zeros := zerosUpTo - zerosBefore
		if k < zeros {
			// The answer has a 0 at this bit: stay in the zero partition.
			start, end = zerosBefore, zerosUpTo
			continue
		}
		// Skip the zeros and continue inside the one partition.
		value |= 1 << bit
		k -= zeros
		start += wm.mid[bit] - zerosBefore
		end += wm.mid[bit] - zerosUpTo
	}
	return value
}

// KthSmallestIndex returns the position, as computed by Kth, of the element
// that is the k-th smallest (0-based) within [start, end). It returns -1 when
// k is negative or not less than end-start.
// NOTE(review): presumably the returned position is an index into the
// original sequence; confirm against Kth.
func (wm *WaveletMatrixStatic) KthSmallestIndex(start, end int32, k int32) int32 {
	if k < 0 || k >= end-start {
		return -1
	}
	// Phase 1: same descent as KthSmallest. val accumulates the answer's bits
	// from most to least significant, [start, end) is narrowed level by level,
	// and k ends up as the offset of the wanted element among equal values.
	val := 0
	for i := wm.bitLen - 1; i >= 0; i-- {
		numOfZeroBegin := wm.bv[i].Count0(start)
		numOfZeroEnd := wm.bv[i].Count0(end)
		numOfZero := numOfZeroEnd - numOfZeroBegin
		bit := 0
		if k >= numOfZero {
			bit = 1
		}
		if bit == 1 {
			k -= numOfZero
			start = wm.mid[i] + start - numOfZeroBegin
			end = wm.mid[i] + end - numOfZeroEnd
		} else {
			start = numOfZeroBegin
			end = numOfZeroBegin + numOfZero
		}
		val = (val << 1) | bit
	}

	// Phase 2: descend again from position 0 along val's bits to find where
	// val's block begins after the last level.
	left := int32(0)
	for i := wm.bitLen - 1; i >= 0; i-- {
		bit := int8((val >> i) & 1)
		left = wm.bv[i].Count(left, bit)
		if bit == 1 {
			left += wm.mid[i]
		}
	}
	// start+k is the wanted entry's position after the last level; subtracting
	// the block start gives its occurrence number, which Kth maps to a position.
	rank := start + k - left
	return wm.Kth(rank, val)
}

// KthLargest returns the k-th largest value (0-based) in [start, end) by
// converting k to the matching ascending rank and delegating to KthSmallest.
func (wm *WaveletMatrixStatic) KthLargest(start, end int32, k int32) int {
	ascending := (end - start) - 1 - k
	return wm.KthSmallest(start, end, ascending)
}
// KthLargestIndex is the index-returning counterpart of KthLargest: it
// converts k to the matching ascending rank and delegates to KthSmallestIndex.
func (wm *WaveletMatrixStatic) KthLargestIndex(start, end int32, k int32) int32 {
	ascending := (end - start) - 1 - k
	return wm.KthSmallestIndex(start, end, ascending)
}

// Floor returns the largest value <= x among the elements in [start, end).
// The boolean is false (with value -1) when no such element exists.
func (wm *WaveletMatrixStatic) Floor(start, end int32, x int) (int, bool) {
	equal, smaller, _ := wm.CountAll(start, end, x)
	switch {
	case equal > 0:
		return x, true
	case smaller == 0:
		return -1, false
	default:
		// The largest of the `smaller` values sits at ascending rank smaller-1.
		return wm.KthSmallest(start, end, smaller-1), true
	}
}
// Lower returns the largest value strictly less than x among the elements in
// [start, end). The boolean is false (with value -1) when CountLess reports
// no smaller element.
func (wm *WaveletMatrixStatic) Lower(start, end int32, x int) (int, bool) {
	smaller := wm.CountLess(start, end, x)
	if smaller != 0 {
		return wm.KthSmallest(start, end, smaller-1), true
	}
	return -1, false
}

// Ceil returns the smallest value >= x among the elements in [start, end).
// The boolean is false (with value -1) when every element counted is smaller
// than x.
func (wm *WaveletMatrixStatic) Ceil(start, end int32, x int) (int, bool) {
	equal, smaller, _ := wm.CountAll(start, end, x)
	switch {
	case equal > 0:
		return x, true
	case smaller == end-start:
		return -1, false
	default:
		// Skipping the `smaller` values lands on the first one above x.
		return wm.KthSmallest(start, end, smaller), true
	}
}

// Higher returns the smallest value strictly greater than x among the
// elements in [start, end). It counts the elements below x+1 with CountLess;
// the boolean is false (with value -1) when that count covers the whole range.
func (wm *WaveletMatrixStatic) Higher(start, end int32, x int) (int, bool) {
	notGreater := wm.CountLess(start, end, x+1)
	if notGreater != end-start {
		return wm.KthSmallest(start, end, notGreater), true
	}
	return -1, false
}

// CountAll returns, for the elements in [start, end), how many are equal to
// x, how many are less than x and how many are greater than x.
//
// The range is clamped to [0, size); an empty range yields (0, 0, 0). When x
// exceeds maxValue every element is reported as less.
func (wm *WaveletMatrixStatic) CountAll(start, end int32, x int) (same, less, more int32) {
	lo, hi := start, end
	if lo < 0 {
		lo = 0
	}
	if hi > wm.size {
		hi = wm.size
	}
	if hi <= lo {
		return 0, 0, 0
	}
	total := hi - lo
	if x > wm.maxValue {
		return 0, total, 0
	}
	for level := wm.bitLen - 1; level >= 0; level-- {
		if lo >= hi {
			// Nothing left that shares x's prefix; all elements are classified.
			break
		}
		zerosLo, zerosHi := wm.bv[level].Count0(lo), wm.bv[level].Count0(hi)
		zeros := zerosHi - zerosLo
		ones := (hi - lo) - zeros
		if x>>level&1 == 1 {
			// Elements with a 0 here (same higher bits) are smaller than x.
			less += zeros
			lo, hi = wm.mid[level]+lo-zerosLo, wm.mid[level]+hi-zerosHi
		} else {
			// Elements with a 1 here (same higher bits) are larger than x.
			more += ones
			lo, hi = zerosLo, zerosHi
		}
	}
	return total - less - more, less, more
}

// Get reconstructs the value stored at index by reading one bit per level,
// from the top level down. A negative index has size added to it first.
func (wm *WaveletMatrixStatic) Get(index int32) int {
	pos := index
	if pos < 0 {
		pos += wm.size
	}
	value := 0
	for bit := wm.bitLen - 1; bit >= 0; bit-- {
		level := wm.bv[bit]
		if !level.Get(pos) {
			pos = level.Count0(pos)
			continue
		}
		value |= 1 << bit
		pos = level.Count1(pos) + wm.mid[bit]
	}
	return value
}

// CountLess returns the number of elements in [start, end) that are strictly
// less than x.
//
// Stored values are non-negative and occupy at most bitLen bits, so:
//   - x <= 0 can have nothing below it and yields 0;
//   - an x that does not fit in bitLen bits is above every stored value and
//     yields end-start. Without this guard the bits of x above bitLen would
//     be ignored and the count would come out too small, e.g. for the x+1
//     that Higher and CountMore pass in.
//
// start and end are not clamped here; callers must pass a valid range.
func (wm *WaveletMatrixStatic) CountLess(start, end int32, x int) int32 {
	if x <= 0 {
		return 0
	}
	// For positive x a shift by bitLen (at most 63) is zero exactly when x
	// fits in bitLen bits; this avoids overflowing 1<<bitLen.
	if x>>wm.bitLen != 0 {
		return end - start
	}
	res := int32(0)
	for bit := wm.bitLen - 1; bit >= 0; bit-- {
		l0, r0 := wm.bv[bit].Count0(start), wm.bv[bit].Count0(end)
		if x>>bit&1 == 1 {
			// Elements with a 0 at this bit (and the same higher bits) are
			// smaller than x; count them and continue among the ones.
			res += r0 - l0
			start += wm.mid[bit] - l0
			end += wm.mid[bit] - r0
		} else {
			start, end = l0, r0
		}
	}
	return res
}

// CountMore returns the number of elements in [start, end) that are strictly
// greater than x, computed as the range length minus the count of elements
// below x+1.
func (wm *WaveletMatrixStatic) CountMore(start, end int32, x int) int32 {
	notGreater := wm.CountLess(start, end, x+1)
	return end - start - notGreater
}

// CountSame returns the number of elements in [start, end) equal to x, as
// reported by CountAll.
func (wm *WaveletMatrixStatic) CountSame(start, end int32, x int) int32 {
	equal, _, _ := wm.CountAll(start, end, x)
	return equal
}

// Len returns the number of elements the matrix was constructed with.
func (wm *WaveletMatrixStatic) Len() int32 {
	return wm.size
}

// _build fills wm.bv and wm.mid from the n values f(0), ..., f(n-1).
//
// Levels are processed from the highest bit down. At each level the current
// ordering is recorded in that level's bit vector and then stably partitioned
// so that elements with a 0 at the bit come first, followed by those with a 1.
func (wm *WaveletMatrixStatic) _build(n int32, f func(int32) int) {
	data := make([]int, n)
	for i := int32(0); i < n; i++ {
		data[i] = f(i)
	}
	// Scratch buffers receiving the zero-bit and one-bit elements of a level.
	zero, one := make([]int, n), make([]int, n)
	for bit := wm.bitLen - 1; bit >= 0; bit-- {
		wm.bv[bit] = newBitVector(n)
		v := wm.bv[bit]
		// p counts elements with a 0 at this bit, q those with a 1.
		p, q := int32(0), int32(0)
		for i, e := range data {
			if e>>bit&1 == 1 {
				v.Set(int32(i))
				one[q] = e
				q++
			} else {
				zero[p] = e
				p++
			}
		}
		v.Build()
		wm.mid[bit] = p
		// The zero buffer (zeros already in place at [0, p)) becomes the new
		// data and the old data slice is recycled as scratch; the ones are
		// then appended after the zeros.
		zero, data = data, zero
		copy(data[p:], one[:q])
	}
}

// bitVector is a fixed-length bit array with per-word prefix popcounts, giving
// constant-time rank (Count0/Count1) and binary-search select (Kth0/Kth1).
// Set the bits first, then call Build before issuing any query.
type bitVector struct {
	n      int32    // number of addressable bits
	size   int32    // number of 64-bit words covering n bits
	bit    []uint64 // size+1 words; the spare word keeps rank queries at a word boundary in range
	preSum []int32  // after Build, preSum[w] is the number of set bits in words [0, w)
}

// newBitVector returns an all-zero bit vector holding n bits.
func newBitVector(n int32) *bitVector {
	words := (n + 63) >> 6
	return &bitVector{
		n:      n,
		size:   words,
		bit:    make([]uint64, words+1),
		preSum: make([]int32, words+1),
	}
}

// Set turns bit i on. It does not refresh the prefix sums.
func (bv *bitVector) Set(i int32) {
	word, offset := i>>6, uint(i&63)
	bv.bit[word] |= uint64(1) << offset
}

// Build recomputes the per-word prefix popcounts from the current bits.
func (bv *bitVector) Build() {
	running := bv.preSum[0]
	for w := int32(0); w < bv.size; w++ {
		running += int32(bits.OnesCount64(bv.bit[w]))
		bv.preSum[w+1] = running
	}
}

// Get reports whether bit i is set.
func (bv *bitVector) Get(i int32) bool {
	mask := uint64(1) << uint(i&63)
	return bv.bit[i>>6]&mask != 0
}

// Count0 returns the number of clear bits in [0, end).
func (bv *bitVector) Count0(end int32) int32 {
	return end - bv.Count1(end)
}

// Count1 returns the number of set bits in [0, end).
func (bv *bitVector) Count1(end int32) int32 {
	word := end >> 6
	// Mask selecting the bits of the partial word that lie below end.
	below := uint64(1)<<uint(end&63) - 1
	return bv.preSum[word] + int32(bits.OnesCount64(bv.bit[word]&below))
}

// Count returns the number of bits in [0, end) that are set when value is 1,
// and the number that are clear for any other value.
func (bv *bitVector) Count(end int32, value int8) int32 {
	ones := bv.Count1(end)
	if value != 1 {
		return end - ones
	}
	return ones
}

// Kth0 returns the position of the k-th clear bit (0-based), or -1 when k is
// negative or there are not more than k clear bits among the first n.
func (bv *bitVector) Kth0(k int32) int32 {
	if k < 0 || k >= bv.Count0(bv.n) {
		return -1
	}
	// Find the last word index w such that the words before it hold at most k
	// clear bits; the answer lies inside word w.
	w, limit := int32(0), bv.size+1
	for limit-w > 1 {
		m := (w + limit) >> 1
		if zerosBefore := m<<6 - bv.preSum[m]; zerosBefore > k {
			limit = m
		} else {
			w = m
		}
	}
	// Inside that word, find the last position with at most k clear bits
	// before it: that position is the k-th clear bit.
	pos, stop := w<<6, w<<6+64
	for stop-pos > 1 {
		m := (pos + stop) >> 1
		if bv.Count0(m) > k {
			stop = m
		} else {
			pos = m
		}
	}
	return pos
}

// Kth1 returns the position of the k-th set bit (0-based), or -1 when k is
// negative or there are not more than k set bits among the first n.
func (bv *bitVector) Kth1(k int32) int32 {
	if k < 0 || k >= bv.Count1(bv.n) {
		return -1
	}
	// Find the last word index w such that the words before it hold at most k
	// set bits; the answer lies inside word w.
	w, limit := int32(0), bv.size+1
	for limit-w > 1 {
		m := (w + limit) >> 1
		if bv.preSum[m] > k {
			limit = m
		} else {
			w = m
		}
	}
	// Inside that word, find the last position with at most k set bits before
	// it: that position is the k-th set bit.
	pos, stop := w<<6, w<<6+64
	for stop-pos > 1 {
		m := (pos + stop) >> 1
		if bv.Count1(m) > k {
			stop = m
		} else {
			pos = m
		}
	}
	return pos
}

// Kth dispatches to Kth1 when v is 1 and to Kth0 for any other v.
func (bv *bitVector) Kth(k int32, v int8) int32 {
	if v != 1 {
		return bv.Kth0(k)
	}
	return bv.Kth1(k)
}

// GetAll returns the first n bits as a slice of booleans.
func (bv *bitVector) GetAll() []bool {
	out := make([]bool, bv.n)
	for i := range out {
		out[i] = bv.Get(int32(i))
	}
	return out
}

// abs returns the absolute value of x.
func abs(x int) int {
	if x >= 0 {
		return x
	}
	return -x
}

// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}

// max returns the larger of a and b.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
0