結果

問題 No.1270 Range Arrange Query
ユーザー lam6er
提出日時 2025-04-09 21:00:52
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 4,596 bytes
コンパイル時間 236 ms
コンパイル使用メモリ 82,172 KB
実行使用メモリ 262,332 KB
最終ジャッジ日時 2025-04-09 21:01:49
合計ジャッジ時間 22,990 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other WA * 8 TLE * 1 -- * 6
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

class SegmentTree:
    def __init__(self, data, func, default):
        self.n = len(data)
        self.func = func
        self.default = default
        self.size = 1
        while self.size < self.n:
            self.size <<= 1
        self.tree = [default] * (2 * self.size)
        for i in range(self.n):
            self.tree[self.size + i] = data[i]
        for i in range(self.size - 1, 0, -1):
            self.tree[i] = func(self.tree[2 * i], self.tree[2 * i + 1])
    
    def query(self, l, r):
        res = self.default
        l += self.size
        r += self.size
        while l <= r:
            if l % 2 == 1:
                res = self.func(res, self.tree[l])
                l += 1
            if r % 2 == 0:
                res = self.func(res, self.tree[r])
                r -= 1
            l >>= 1
            r >>= 1
        return res

class FenwickTree:
    def __init__(self, size):
        self.n = size
        self.tree = [0] * (self.n + 2)
    
    def update(self, idx, delta=1):
        while idx <= self.n:
            self.tree[idx] += delta
            idx += idx & -idx
    
    def query(self, idx):
        res = 0
        while idx > 0:
            res += self.tree[idx]
            idx -= idx & -idx
        return res

def main():
    input = sys.stdin.read().split()
    ptr = 0
    N, Q = int(input[ptr]), int(input[ptr+1])
    ptr += 2
    a = list(map(int, input[ptr:ptr+N]))
    ptr += N
    queries = []
    for q in range(Q):
        l, r = int(input[ptr]), int(input[ptr+1])
        ptr += 2
        queries.append((l, r, q))
    
    # Preprocess left_inversion
    left_inversion = [0] * (N + 1)
    bit = FenwickTree(N)
    for i in range(N):
        val = a[i]
        count = bit.query(N) - bit.query(val)
        left_inversion[i+1] = left_inversion[i] + count
        bit.update(val)
    
    # Preprocess right_inversion
    right_inversion = [0] * (N + 2)
    bit = FenwickTree(N)
    for i in reversed(range(N)):
        val = a[i]
        count = bit.query(val - 1)
        right_inversion[i+1] = right_inversion[i+2] + count
        bit.update(val)
    
    # Prepare Segment Trees for max and min queries
    max_data = a.copy()
    min_data = a.copy()
    st_max = SegmentTree(max_data, max, -float('inf'))
    st_min = SegmentTree(min_data, min, float('inf'))
    
    # Process queries offline for cross_inversion
    sorted_queries = sorted(queries, key=lambda x: (-x[1], x[0]))
    cross_ans = [0] * Q
    cross_bit = FenwickTree(N)
    current_r = N + 1  # j starts from current_r to N
    
    event_map = {}
    for q in sorted_queries:
        r = q[1]
        if r not in event_map:
            event_map[r] = []
        event_map[r].append(q)
    
    # Process in descending order of r
    for r in range(N, -1, -1):
        # j must be >= r+1 (current_r starts from r+1)
        if r < N:
            j = r + 1
            cross_bit.update(a[j-1], 1)
        if r in event_map:
            for q in event_map[r]:
                l, r_query, q_idx = q
                total = 0
                if l > 1:
                    # To calculate sum of cross_bit.query(ai-1) for ai in a[0..l-2]
                    # Precompute prefix sums
                    # We need to query for each ai in a[0..l-2]
                    # Using another BIT to accumulate
                    temp_bit = FenwickTree(N)
                    prefix = [0] * (l)
                    for i in range(l-1):
                        ai = a[i]
                        cnt = cross_bit.query(ai - 1)
                        total += cnt
                cross_ans[q_idx] = total
    
    result = [0] * Q
    for q in queries:
        l, r, q_idx = q
        Lmax = -1
        L_size = l - 1
        if L_size > 0:
            Lmax = st_max.query(0, l-2)
        else:
            Lmax = -float('inf')
        Rmin = float('inf')
        R_size = N - r
        if R_size > 0:
            Rmin = st_min.query(r, N-1)
        else:
            Rmin = float('inf')
        k = r - l + 1
        
        left_inv = left_inversion[l-1] if l > 0 else 0
        right_inv = right_inversion[r+1] if r < N else 0
        cross_inv = cross_ans[q_idx]
        
        if Lmax <= Rmin:
            min_inv = left_inv + right_inv + cross_inv
        else:
            add = min(L_size * k, R_size * k)
            min_inv = left_inv + right_inv + cross_inv + add
        
        if l == 1 and r == N:
            min_inv = 0
        
        result[q_idx] = min_inv
    
    for ans in result:
        print(ans)

if __name__ == '__main__':
    main()
0