結果

問題 No.2028 Even Choice
ユーザー gew1fw
提出日時 2025-06-12 21:38:04
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 4,114 bytes
コンパイル時間 201 ms
コンパイル使用メモリ 82,252 KB
実行使用メモリ 170,736 KB
最終ジャッジ日時 2025-06-12 21:41:36
合計ジャッジ時間 6,848 ms
ジャッジサーバーID
(参考情報)
judge5 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 3
other TLE * 1 -- * 27
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def main():
    import sys
    sys.setrecursionlimit(1 << 25)
    N, K = map(int, sys.stdin.readline().split())
    A = list(map(int, sys.stdin.readline().split()))
    
    # Segment tree node: (max_even, max_odd)
    class Node:
        __slots__ = ['l', 'r', 'left', 'right', 'max_even', 'max_odd', 'flip']
        def __init__(self, l, r):
            self.l = l
            self.r = r
            self.left = None
            self.right = None
            self.max_even = -float('inf')
            self.max_odd = -float('inf')
            self.flip = False  # To track if the parity is flipped
    
    def build(l, r):
        node = Node(l, r)
        if l == r:
            val = A[l-1]  # 1-based
            parity = (l) % 2
            if parity == 0:
                node.max_even = val
                node.max_odd = -float('inf')
            else:
                node.max_odd = val
                node.max_even = -float('inf')
            return node
        mid = (l + r) // 2
        node.left = build(l, mid)
        node.right = build(mid+1, r)
        node.max_even = max(node.left.max_even, node.right.max_even)
        node.max_odd = max(node.left.max_odd, node.right.max_odd)
        return node
    
    def push(node):
        if node.flip and node.left:
            # Flip left and right
            node.left.max_even, node.left.max_odd = node.left.max_odd, node.left.max_even
            node.left.flip = not node.left.flip
            node.right.max_even, node.right.max_odd = node.right.max_odd, node.right.max_even
            node.right.flip = not node.right.flip
            node.flip = False
    
    def update(node, idx, val):
        if node.l == node.r == idx:
            parity = (idx) % 2
            if parity == 0:
                node.max_even = val
                node.max_odd = -float('inf')
            else:
                node.max_odd = val
                node.max_even = -float('inf')
            return
        push(node)
        mid = (node.l + node.r) // 2
        if idx <= mid:
            update(node.left, idx, val)
        else:
            update(node.right, idx, val)
        node.max_even = max(node.left.max_even, node.right.max_even)
        node.max_odd = max(node.left.max_odd, node.right.max_odd)
    
    def range_flip(node, l, r):
        if node.r < l or node.l > r:
            return
        if l <= node.l and node.r <= r:
            node.max_even, node.max_odd = node.max_odd, node.max_even
            node.flip = not node.flip
            return
        push(node)
        range_flip(node.left, l, r)
        range_flip(node.right, l, r)
        node.max_even = max(node.left.max_even, node.right.max_even)
        node.max_odd = max(node.left.max_odd, node.right.max_odd)
    
    def query_max_even(node, l, r):
        if node.r < l or node.l > r:
            return -float('inf')
        if l <= node.l and node.r <= r:
            return node.max_even
        push(node)
        left_max = query_max_even(node.left, l, r)
        right_max = query_max_even(node.right, l, r)
        return max(left_max, right_max)
    
    root = build(1, N)
    
    total = 0
    for _ in range(K):
        max_val = -float('inf')
        max_idx = -1
        current_max = query_max_even(root, 1, N)
        if current_max == -float('inf'):
            break
        max_val = current_max
        
        # Find the index with max_val
        stack = [root]
        while stack:
            node = stack.pop()
            if node.l == node.r:
                if node.max_even == max_val:
                    max_idx = node.l
                continue
            push(node)
            if node.left.max_even >= max_val:
                stack.append(node.left)
            if node.right.max_even >= max_val:
                stack.append(node.right)
        
        # Remove the max_idx
        update(root, max_idx, -float('inf'))
        # All positions to the right of max_idx have their parity flipped
        range_flip(root, max_idx+1, N)
        total += max_val
    
    print(total)

if __name__ == "__main__":
    main()
0