結果
| 問題 | No.738 平らな農地 |
| コンテスト | |
| ユーザー |
terasa
|
| 提出日時 | 2023-01-09 13:20:29 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
AC
|
| 実行時間 | 1,125 ms / 2,000 ms |
| コード長 | 9,117 bytes |
| 記録 | |
| コンパイル時間 | 325 ms |
| コンパイル使用メモリ | 82,380 KB |
| 実行使用メモリ | 174,992 KB |
| 最終ジャッジ日時 | 2024-12-17 17:22:56 |
| 合計ジャッジ時間 | 46,178 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 5 |
| other | AC * 87 |
ソースコード
from typing import List, Tuple, Callable, TypeVar, Optional
import sys
import itertools
import heapq
import bisect
import math
from collections import deque, defaultdict
from functools import lru_cache, cmp_to_key
input = sys.stdin.readline
if __file__ != 'prog.py':
sys.setrecursionlimit(10 ** 6)
def readints(): return map(int, input().split())
def readlist(): return list(readints())
def readstr(): return input()[:-1]
def readlist1(): return list(map(lambda x: int(x) - 1, input().split()))
class BitVector:
# reference: https://tiramister.net/blog/posts/bitvector/
TABLE = bytes([
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8,
])
def __init__(self, N):
self.cnum = (N + 255) >> 8
self.bit = bytearray(self.cnum << 5)
self.chunk = [0] * (self.cnum + 1)
self.blocks = bytearray(self.cnum << 5)
self.built = False
def set(self, pos):
self.bit[pos >> 3] |= 1 << (pos & 7)
def access(self, pos):
return self.bit[pos >> 3] >> (pos & 7) & 1
def popcount(self, num):
return self.TABLE[num]
def build(self):
for i in range(self.cnum):
k = i << 5
for j in range(31):
self.blocks[k + 1] = self.blocks[k] + self.popcount(self.bit[k])
k += 1
self.chunk[i + 1] = self.chunk[i] + self.blocks[k] + self.popcount(self.bit[k])
self.built = True
def rank(self, pos):
assert self.built
cpos, tmp = pos >> 8, pos & 255
bpos, offset = tmp >> 3, tmp & 7
i = cpos << 5 | bpos
rest = self.bit[i] & ((1 << offset) - 1)
return self.chunk[cpos] + self.blocks[i] + self.popcount(rest)
def select(self, num):
"""return minimum i that satisfies rank(i) = num"""
assert self.built
if num == 0:
return 0
if self.rank(self.N) < num:
return -1
l = -1
r = self.N
while r - l > 1:
c = (l + r) >> 1
if self.rank(c) >= num:
r = c
else:
l = c
return r
class WaveletMatrix:
# reference: https://miti-7.hatenablog.com/entry/2018/04/28/152259
def __init__(self, A):
self.nums = sorted(set(A))
self.idx = {a: i for i, a in enumerate(self.nums)}
self.A = [self.idx[a] for a in A]
self.digit = (len(self.nums) - 1).bit_length()
self.B = [None] * self.digit
self.offset = [None] * self.digit
self.start_index = [-1] * len(self.nums)
self.S = [[0 for _ in range(len(self.A) + 1)] for _ in range(self.digit + 1)]
for i, a in enumerate(self.A):
self.S[self.digit][i + 1] = self.S[self.digit][i] + self.nums[a]
T = self.A
for k in range(self.digit)[::-1]:
self.B[k] = BitVector(len(T) + 1)
zeros = []
ones = []
for i, a in enumerate(T):
if a >> k & 1:
self.B[k].set(i)
ones.append(a)
else:
zeros.append(a)
self.B[k].build()
self.offset[k] = len(zeros)
T = zeros + ones
for i, a in enumerate(T):
self.S[k][i + 1] = self.S[k][i] + self.nums[a]
for i, a in enumerate(T):
if self.start_index[a] < 0:
self.start_index[a] = i
def access(self, i):
"""return i-th value"""
ret = 0
cur = i
for k in range(self.digit)[::-1]:
if self.B[k].access(cur):
ret |= (1 << k)
cur = self.B[k].rank(cur) + self.offset[k]
else:
cur -= self.B[k].rank(cur)
return self.nums[ret]
def rank(self, i, x):
"""return the number of x's in [0, i) range"""
x = self.idx.get(x)
if x is None:
return 0
for k in range(self.digit)[::-1]:
if x >> k & 1:
i = self.B[k].rank(i) + self.offset[k]
else:
i -= self.B[k].rank(i)
return i - self.start_index[x]
def quantile(self, l, r, n):
"""return n-th (0-indexed) smallest value in [l, r) range"""
assert 0 <= n < r - l
ret = 0
for k in range(self.digit)[::-1]:
rank_l = self.B[k].rank(l)
rank_r = self.B[k].rank(r)
ones = rank_r - rank_l
zeros = r - l - ones
if zeros <= n:
ret |= 1 << k
l = rank_l + self.offset[k]
r = rank_r + self.offset[k]
n -= zeros
else:
l -= rank_l
r -= rank_r
return self.nums[ret]
def range_freq(self, l, r, lower, upper):
"""return the number of values s.t. lower <= x < upper"""
return self._range_freq_upper(l, r, upper) - self._range_freq_upper(l, r, lower)
def prev_value(self, l, r, upper):
"""return maximum x s.t. x < upper in [l, r) range if exist, otherwise None"""
cnt = self._range_freq_upper(l, r, upper)
if cnt == 0:
return None
return self.quantile(l, r, cnt - 1)
def next_value(self, l, r, lower):
"""return minimum x s.t. x >= lower in [l, r) range if exist, otherwise None"""
cnt = self._range_freq_upper(l, r, lower)
if cnt == r - l:
return None
return self.quantile(l, r, cnt)
def range_sum(self, l, r, lower, upper):
return self._range_sum_upper(l, r, upper) - self._range_sum_upper(l, r, lower)
def range_sum_topn(self, l, r, n):
assert 0 <= n < r - l
if self.digit == 0:
return self.nums[0] * (n + 1)
ret = 0
for k in range(self.digit)[::-1]:
rank_l = self.B[k].rank(l)
rank_r = self.B[k].rank(r)
ones = rank_r - rank_l
zeros = r - l - ones
if zeros <= n:
ret += self.S[k][r - rank_r] - self.S[k][l - rank_l]
l = rank_l + self.offset[k]
r = rank_r + self.offset[k]
n -= zeros
else:
l -= rank_l
r -= rank_r
ret += self.S[0][l + n + 1] - self.S[0][l]
return ret
def _range_freq_upper(self, l, r, upper):
"""return the number of values s.t. x < upper in [l, r) range"""
if l >= r:
return 0
if upper > self.nums[-1]:
return r - l
if upper <= self.nums[0]:
return 0
upper = bisect.bisect_left(self.nums, upper)
ret = 0
for k in range(self.digit)[::-1]:
rank_l = self.B[k].rank(l)
rank_r = self.B[k].rank(r)
ones = rank_r - rank_l
zeros = r - l - ones
if upper >> k & 1:
ret += zeros
l = rank_l + self.offset[k]
r = rank_r + self.offset[k]
else:
l -= rank_l
r -= rank_r
return ret
def _range_sum_upper(self, l, r, upper):
if l >= r:
return 0
if upper > self.nums[-1]:
return self.S[self.digit][r] - self.S[self.digit][l]
if upper <= self.nums[0]:
return 0
upper = bisect.bisect_left(self.nums, upper)
ret = 0
for k in range(self.digit)[::-1]:
rank_l = self.B[k].rank(l)
rank_r = self.B[k].rank(r)
ones = rank_r - rank_l
zero = r - l - ones
if upper >> k & 1:
ret += self.S[k][r - rank_r] - self.S[k][l - rank_l]
l = rank_l + self.offset[k]
r = rank_r + self.offset[k]
else:
l -= rank_l
r -= rank_r
return ret
N, K = readints()
A = readlist()
W = WaveletMatrix(A)
INF = 1 << 50
ans = INF
for i in range(N - K + 1):
med = W.quantile(i, i + K, K // 2)
upper_sum = W.range_sum(i, i + K, med, INF) - med * W.range_freq(i, i + K, med, INF)
lower_sum = med * W.range_freq(i, i + K, 0, med) - W.range_sum(i, i + K, 0, med)
ans = min(ans, upper_sum + lower_sum)
print(ans)
terasa