結果

問題 No.2028 Even Choice
ユーザー lam6er
提出日時 2025-03-26 15:56:12
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,127 bytes
コンパイル時間 319 ms
コンパイル使用メモリ 83,036 KB
実行使用メモリ 129,876 KB
最終ジャッジ日時 2025-03-26 15:56:46
合計ジャッジ時間 8,248 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 3
other WA * 3 TLE * 1 -- * 24
権限があれば一括ダウンロードができます

ソースコード

diff #

import heapq

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    N = int(data[0])
    K = int(data[1])
    A = list(map(int, data[2:2+N]))
    
    # We need to build two segment trees: one for even indices and one for odd indices (original 1-based)
    # Here, we'll use 0-based indices for the original array.
    max_even = [-float('inf')] * (N + 1)
    max_odd = [-float('inf')] * (N + 1)
    pos_even = [-1] * (N + 1)
    pos_odd = [-1] * (N + 1)
    
    # Precompute the maximum values for even and odd positions (original 1-based)
    for i in range(N):
        original_pos = i + 1
        if original_pos % 2 == 0:
            max_even[i] = A[i]
            pos_even[i] = i
        else:
            max_odd[i] = A[i]
            pos_odd[i] = i
    
    # Build segment trees for max_even and max_odd
    # For simplicity, use a list-based segment tree implementation (efficient for this problem)
    # Here, we'll use a binary indexed tree (Fenwick Tree) for max queries (not ideal, but for the sake of time)
    # However, for the sake of time and correctness, we'll use a different approach.
    # Instead, we'll use a heap-based approach with intervals.
    
    # We'll represent each interval as (max_value, start, end, flip, is_even)
    # flip: 0 for no flip, 1 for flip (even becomes odd and vice versa)
    heap = []
    
    def get_max(start, end, flip):
        if start > end:
            return (-float('inf'), -1)
        max_val = -float('inf')
        pos = -1
        for i in range(start, end + 1):
            original_pos = i + 1
            if (original_pos % 2 == 0) ^ (flip % 2 == 1):
                if A[i] > max_val:
                    max_val = A[i]
                    pos = i
        return (max_val, pos)
    
    # Initial interval: entire array, flip 0
    initial_max, initial_pos = get_max(0, N-1, 0)
    if initial_max != -float('inf'):
        heapq.heappush(heap, (-initial_max, 0, N-1, 0, initial_pos))
    
    total = 0
    count = 0
    
    while heap and count < K:
        current = heapq.heappop(heap)
        current_max = -current[0]
        start = current[1]
        end = current[2]
        flip = current[3]
        pos = current[4]
        
        total += current_max
        count += 1
        
        # Split into left and right intervals
        # Left: start to pos-1, same flip
        left_start = start
        left_end = pos - 1
        if left_start <= left_end:
            left_max, left_pos = get_max(left_start, left_end, flip)
            if left_max != -float('inf'):
                heapq.heappush(heap, (-left_max, left_start, left_end, flip, left_pos))
        
        # Right: pos+1 to end, flip + 1
        right_start = pos + 1
        right_end = end
        if right_start <= right_end:
            right_flip = flip + 1
            right_max, right_pos = get_max(right_start, right_end, right_flip)
            if right_max != -float('inf'):
                heapq.heappush(heap, (-right_max, right_start, right_end, right_flip, right_pos))
    
    print(total)

if __name__ == "__main__":
    main()
0