結果

問題 No.738 平らな農地
コンテスト
ユーザー dokukuma
提出日時 2026-03-20 00:06:15
言語 PyPy3
(7.3.17)
コンパイル:
pypy3 -mpy_compile _filename_
実行:
pypy3 _filename_
結果
AC  
実行時間 609 ms / 2,000 ms
コード長 9,143 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 1,130 ms
コンパイル使用メモリ 85,272 KB
実行使用メモリ 250,016 KB
最終ジャッジ日時 2026-03-20 00:06:48
合計ジャッジ時間 30,124 ms
ジャッジサーバーID
(参考情報)
judge1_0 / judge2_0
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 5
other AC * 87
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys
import itertools
from typing import List

class BitVector:
    """簡潔ビットベクトル (最速ビット演算最適化版)"""
    def __init__(self, n: int):
        self.n = n
        self.block = (n >> 5) + 1
        self.bit = [0] * self.block
        self.cnt = [0] * self.block
        self.zeros = 0

    def set(self, i: int) -> None:
        self.bit[i >> 5] |= 1 << (i & 31)

    def get(self, i: int) -> int:
        return (self.bit[i >> 5] >> (i & 31)) & 1

    def build(self) -> None:
        for i in range(self.block - 1):
            self.cnt[i + 1] = self.cnt[i] + self.popcount(self.bit[i])
        self.zeros = self.rank0(self.n)

    def popcount(self, x: int) -> int:
        x = x - ((x >> 1) & 0x55555555)
        x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
        x = (x + (x >> 4)) & 0x0f0f0f0f
        x = x + (x >> 8)
        x = x + (x >> 16)
        return x & 0x0000007f

    def rank1(self, i: int) -> int:
        mask = (1 << (i & 31)) - 1
        return self.cnt[i >> 5] + self.popcount(self.bit[i >> 5] & mask)

    def rank0(self, i: int) -> int:
        return i - self.rank1(i)


class WaveletMatrix:

    def __init__(self, A: List[int]):
        self.n = max(len(A), 1)
        
        # 値クエリ用の木 (累積和も同時に構築する)
        self.lg_v = max(A + [1]).bit_length()
        self.b = [BitVector(self.n) for _ in range(self.lg_v)]
        self.sums = [None] * (self.lg_v + 1)
        self._build(A, self.b, self.lg_v, is_value_tree=True)

        # 種類数クエリ用の木
        prev_i = [0] * len(A)
        last_pos = {}
        for i, a in enumerate(A):
            prev_i[i] = last_pos.get(a, -1) + 1
            last_pos[a] = i
            
        self.lg_cnt = max(prev_i + [1]).bit_length()
        self.cnt = [BitVector(self.n) for _ in range(self.lg_cnt)]
        self._build(prev_i, self.cnt, self.lg_cnt, is_value_tree=False)

    def _build(self, arr: List[int], bv_list: List[BitVector], lg: int, is_value_tree: bool) -> None:
        if is_value_tree:
            # 構築前の元配列の累積和を最上層(lg_v)に保持
            self.sums[lg] = [0] + list(itertools.accumulate(arr))

        cur = arr[:]
        for h in range(lg - 1, -1, -1):
            bv = bv_list[h]
            bv_set = bv.set
            
            ls, rs = [], []
            ls_append, rs_append = ls.append, rs.append

            for i, val in enumerate(cur):
                if (val >> h) & 1:
                    bv_set(i)
                    rs_append(val)
                else:
                    ls_append(val)
            
            bv.build()
            cur = ls + rs
            
            if is_value_tree:
                self.sums[h] = [0] + list(itertools.accumulate(cur))

    # ====================================================
    # 1. 取得・検索系 (Access / Kth)
    # ====================================================
    def access(self, k: int) -> int:
        """k 番目 (0-indexed) の要素の値を返す"""
        ret = 0
        for h in range(self.lg_v - 1, -1, -1):
            f = self.b[h].get(k)
            if f:
                ret |= 1 << h
                k = self.b[h].rank1(k) + self.b[h].zeros
            else:
                k = self.b[h].rank0(k)
        return ret

    def kth_smallest(self, l: int, r: int, k: int) -> int:
        """区間 [l, r) の範囲で k 番目 (0-indexed) に小さい値を返す"""
        res = 0
        for h in range(self.lg_v - 1, -1, -1):
            l0 = self.b[h].rank0(l)
            r0 = self.b[h].rank0(r)
            if k < r0 - l0:
                l, r = l0, r0
            else:
                k -= r0 - l0
                res |= 1 << h
                l += self.b[h].zeros - l0
                r += self.b[h].zeros - r0
        return res

    def kth_largest(self, l: int, r: int, k: int) -> int:
        """区間 [l, r) の範囲で k 番目 (0-indexed) に大きい値を返す"""
        return self.kth_smallest(l, r, r - l - 1 - k)

    # ====================================================
    # 2. カウント・個数系 (Frequency / Count)
    # ====================================================
    def range_freq(self, l: int, r: int, upper: int) -> int:
        """区間 [l, r) の範囲で upper 未満の要素の個数を返す"""
        if upper >= (1 << self.lg_v):
            return r - l
        ret = 0
        for h in range(self.lg_v - 1, -1, -1):
            f = (upper >> h) & 1
            l0 = self.b[h].rank0(l)
            r0 = self.b[h].rank0(r)
            if f:
                ret += r0 - l0
                l += self.b[h].zeros - l0
                r += self.b[h].zeros - r0
            else:
                l, r = l0, r0
        return ret

    def range_freq_range(self, l: int, r: int, lower: int, upper: int) -> int:
        """区間 [l, r) の範囲で lower 以上 upper 未満の要素の個数を返す"""
        if lower >= upper: return 0
        return self.range_freq(l, r, upper) - self.range_freq(l, r, lower)

    def range_count(self, l: int, r: int) -> int:
        """区間 [l, r) の種類数 (Count Distinct) を返す"""
        upper = l + 1
        if upper >= (1 << self.lg_cnt):
            return r - l
        ret = 0
        for h in range(self.lg_cnt - 1, -1, -1):
            f = (upper >> h) & 1
            l0 = self.cnt[h].rank0(l)
            r0 = self.cnt[h].rank0(r)
            if f:
                ret += r0 - l0
                l += self.cnt[h].zeros - l0
                r += self.cnt[h].zeros - r0
            else:
                l, r = l0, r0
        return ret

    # ====================================================
    # 3. 前後探索系 (Predecessor / Successor)
    # ====================================================
    def prev_value(self, l: int, r: int, upper: int) -> int:
        """区間 [l, r) の範囲で upper 未満の最大の値を返す (存在しない場合は -1)"""
        cnt = self.range_freq(l, r, upper)
        return -1 if cnt == 0 else self.kth_smallest(l, r, cnt - 1)

    def next_value(self, l: int, r: int, lower: int) -> int:
        """区間 [l, r) の範囲で lower 以上の最小の値を返す (存在しない場合は -1)"""
        cnt = self.range_freq(l, r, lower)
        return -1 if cnt == r - l else self.kth_smallest(l, r, cnt)

    # ====================================================
    # 4. 総和系 (Sum) 
    # ====================================================
    def kth_sum(self, l: int, r: int, k: int) -> int:
        """区間 [l, r) を昇順ソートしたときの、先頭からk個の要素の「総和」を返す"""
        if k <= 0: return 0
        k = min(k, r - l)
        res_sum = 0
        val = 0
        for h in range(self.lg_v - 1, -1, -1):
            l0 = self.b[h].rank0(l)
            r0 = self.b[h].rank0(r)
            num0 = r0 - l0
            if num0 < k:
                res_sum += self.sums[h][r0] - self.sums[h][l0]
                k -= num0
                val |= 1 << h
                l += self.b[h].zeros - l0
                r += self.b[h].zeros - r0
            else:
                l, r = l0, r0
        # 最後に端数となったk個分の値を足す
        res_sum += k * val
        return res_sum

    def range_sum(self, l: int, r: int, mink: int, maxk: int) -> int:
        """区間 [l, r) の中で、値が mink 以上 maxk 未満である要素の「総和」を返す"""
        if mink >= maxk or l >= r:
            return 0
            
        def range_sum_upper(end_val: int) -> int:
            """区間内の end_val 未満の総和を計算する内部関数"""
            if end_val >= (1 << self.lg_v):
                return self.sums[self.lg_v][r] - self.sums[self.lg_v][l]
            res = 0
            cur_l, cur_r = l, r
            for h in range(self.lg_v - 1, -1, -1):
                f = (end_val >> h) & 1
                l0 = self.b[h].rank0(cur_l)
                r0 = self.b[h].rank0(cur_r)
                if f:
                    # 注目ビットが1なら、0のグループはすべて end_val 未満になるので総和を足す
                    res += self.sums[h][r0] - self.sums[h][l0]
                    cur_l += self.b[h].zeros - l0
                    cur_r += self.b[h].zeros - r0
                else:
                    cur_l, cur_r = l0, r0
            return res

        return range_sum_upper(maxk) - range_sum_upper(mink)
    
\
input = lambda: sys.stdin.readline().rstrip()
ii = lambda: int(input())
mi = lambda: map(int, input().split())
li = lambda: list(mi())
inf = 2 ** 61 - 1


N, K = mi()
A = li()
WM = WaveletMatrix(A)
ans = inf
pre = [0]
for i in range(N):
    pre.append(pre[-1] + A[i])

for i in range(N-K+1):
    # i ~ i + Kまで、mid(A[i:i+K])に統一するときのコストを計算
    mid = WM.kth_smallest(i, i + K, K//2)
    lower_sum = WM.kth_sum(i, i + K, K // 2)
    upper_sum = pre[i+K] - pre[i] - lower_sum
    value = mid * (K // 2) - lower_sum + upper_sum - mid * (K - K // 2)
    ans = min(ans, value)

print(ans)


    
0