結果

問題 No.738 平らな農地
コンテスト
ユーザー 回転
提出日時 2025-11-26 14:22:32
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,187 ms / 2,000 ms
コード長 6,021 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 291 ms
コンパイル使用メモリ 82,648 KB
実行使用メモリ 152,992 KB
最終ジャッジ日時 2025-11-26 14:23:16
合計ジャッジ時間 42,527 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 5
other AC * 87
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

class WaveletMatrix:
    def __init__(self,V):
        self.n=len(V)
        self.lg=max(V).bit_length()
        self.ranks=[]
        self.accs=[]
        self.original_V=V
        V=list(V)
        for bit in range(self.lg-1,-1,-1):
            rank=[0]*(self.n+1)
            for i,v in enumerate(V):
                rank[i+1]=rank[i]+((v>>bit)&1)
                
            swp=[0]*self.n
            zero,one=0,self.n-rank[self.n]
            for v in V:
                if (v>>bit)&1:
                    swp[one]=v
                    one+=1
                else:
                    swp[zero]=v
                    zero+=1
            
            acc=[0]*(self.n+1)
            for i,v in enumerate(swp):
                acc[i+1]=acc[i]+v
            
            V=swp
            self.ranks.append(rank)
            self.accs.append(acc)
        self.accs.append([0])
        for v in self.original_V:
            self.accs[-1].append(self.accs[-1][-1]+v)
    def access(self,i):
        return self.original_V[i]
    def rank(self,r,x):
        if (x>>self.lg)&1:return 0
        for i in range(self.lg-1,-1,-1):
            bit=(x>>i)&1
            rank=self.ranks[i]
            if bit:
                zeros=self.n-rank[r]
                r=zeros+rank[r]
            else:
                r-=rank[r]
        return r
    def rank_range(self,l,r,x):
        return self.rank(r,x)-self.rank(l,x)
    def quantile(self,l,r,k):
        res=0
        for i in range(self.lg-1,-1,-1):
            rank=self.ranks[self.lg-1-i]
            ones=rank[r]-rank[l]
            zeros=(r-l)-ones
            if k<zeros:
                l-=rank[l]
                r-=rank[r]
            else:
                res|=1<<i
                k-=zeros
                zero_sum=self.n-rank[self.n]
                l=zero_sum+rank[l]
                r=zero_sum+rank[r]
        return res
    def _range_freq(self,l,r,x):
        if x.bit_length() > self.lg:
            return r - l
        res=0
        for i in range(self.lg-1,-1,-1):
            bit=(x>>i)&1
            rank=self.ranks[self.lg-1-i]
            ones=rank[r]-rank[l]
            zeros=(r-l)-ones
            if bit:
                res+=zeros
                zero_sum=self.n-rank[self.n]
                l=zero_sum+rank[l]
                r=zero_sum+rank[r]
            else:
                l-=rank[l]
                r-=rank[r]
        return res
    def range_freq(self,left,right,lower,upper):
        return self._range_freq(left,right,upper)-self._range_freq(left,right,lower)
    def prev_value(self,left,right,upper):
        cnt=self._range_freq(left,right,upper)
        return self.quantile(left,right,cnt-1) if cnt>0 else None
    def next_value(self,left,right,lower):
        cnt=self._range_freq(left,right,lower)
        return self.quantile(left,right,cnt) if cnt<right-left else None
    def _range_sum(self,l,r,x):
        if self.lg<x.bit_length():return self.accs[-1][r]-self.accs[-1][l]
        res=0
        for i in range(self.lg-1,-1,-1):
            bit=(x>>i)&1
            rank=self.ranks[self.lg-1-i]
            acc=self.accs[self.lg-1-i]
            if bit:
                zero_sum=self.n-rank[self.n]
                l0=l-rank[l]
                r0=r-rank[r]
                res+=acc[r0]-acc[l0]
                l=zero_sum+rank[l]
                r=zero_sum+rank[r]
            else:
                l-=rank[l]
                r-=rank[r]
        return res
    def range_sum(self,left,right,lower,upper):
        return self._range_sum(left,right,upper)-self._range_sum(left,right,lower)

    def _build_distinct_wm(self):
        P = [0] * self.n
        last_pos = {}
        for i, v in enumerate(self.original_V):
            P[i] = last_pos.get(v, -1) + 1
            last_pos[v] = i
        self._distinct_wm = WaveletMatrix(P)

    def range_distinct(self, left, right):
        """
        区間 [left, right) に含まれる要素の種類数を返す
        初回呼び出し時にO(N log N)で補助構造を構築し、以降はO(log N)で応答する。
        """
        if not hasattr(self, "_distinct_wm"):
            self._build_distinct_wm()
        
        return self._distinct_wm.range_freq(left, right, 0, left + 1)

    def bottom_k_sum(self, l, r, k):
        """
        区間 [l, r) の中で小さい方から k 個の要素の和を返す
        計算量: O(log N)
        """
        if k <= 0: return 0
        if k >= r - l: return self.accs[-1][r] - self.accs[-1][l]
        
        res = 0
        val = 0
        
        for i in range(self.lg - 1, -1, -1):
            rank = self.ranks[self.lg - 1 - i]
            acc = self.accs[self.lg - 1 - i]
            
            ones = rank[r] - rank[l]
            zeros = (r - l) - ones
            
            l0 = l - rank[l]
            r0 = r - rank[r]
            
            if k <= zeros:
                l = l0
                r = r0
            else:
                res += acc[r0] - acc[l0]
                k -= zeros
                
                val |= (1 << i)
                zero_sum = self.n - rank[self.n]
                l = zero_sum + rank[l]
                r = zero_sum + rank[r]
        
        res += k * val
        return res

    def top_k_sum(self, l, r, k):
        """
        区間 [l, r) の中で大きい方から k 個の要素の和を返す
        計算量: O(log N)
        """
        if k <= 0: return 0
        length = r - l
        if k >= length: return self.accs[-1][r] - self.accs[-1][l]
        
        total_sum = self.accs[-1][r] - self.accs[-1][l]
        return total_sum - self.bottom_k_sum(l, r, length - k)

N,K = list(map(int,input().split()))
A = list(map(int,input().split()))
WM = WaveletMatrix(A)

ans = 10**18
for i in range(N-K+1):
    mid = WM.quantile(i,i+K,K//2)
    tmp = 0
    tmp += mid*WM.range_freq(i,i+K,0,mid) - WM.range_sum(i,i+K,0,mid)
    tmp += WM.range_sum(i,i+K,mid,10**18) - mid*WM.range_freq(i,i+K,mid,10**18)

    ans = min(ans,tmp)
print(ans)
0