結果

問題 No.2434 RAKUTAN de RAKUTAN
ユーザー lam6er
提出日時 2025-04-16 16:28:59
言語 PyPy3 (7.3.15)
結果
WA  
実行時間 -
コード長 3,359 bytes
コンパイル時間 237 ms
コンパイル使用メモリ 82,468 KB
実行使用メモリ 252,752 KB
最終ジャッジ日時 2025-04-16 16:30:11
合計ジャッジ時間 11,518 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 18 WA * 6
権限があれば一括ダウンロードができます

ソースコード

diff #

import bisect

def _count_between(sorted_vals, lo, hi):
    """Return how many values of ascending ``sorted_vals`` lie in [lo, hi].

    An empty range (``lo > hi``) yields 0.
    """
    if lo > hi:
        return 0
    return bisect.bisect_right(sorted_vals, hi) - bisect.bisect_left(sorted_vals, lo)


def solve(n, h, x, good, bad):
    """Return the maximum score for one test case.

    Model assumed by this solution (it is what the interval arithmetic below
    encodes): walking visits every cell; a jump from cell ``p`` with
    ``0 <= p <= n - x`` lands on ``p + x`` and skips cells
    ``p + 1 .. p + x - 1``; at most ``h`` jumps may be used. Each visited
    cell listed in ``good`` scores +1, each visited cell in ``bad`` scores -1.

    Args:
        n: last cell index (a jump may not land beyond it).
        h: maximum number of jumps.
        x: jump length.
        good: cells worth +1 when visited (any order).
        bad: cells worth -1 when visited (any order).

    Returns:
        ``len(good) - len(bad)`` (the score with no jumps) plus the best total
        gain of at most ``h`` mutually compatible jumps, where a jump's gain
        is (#bad cells skipped) - (#good cells skipped).
    """
    good = sorted(good)
    bad = sorted(bad)

    # A jump can only change the score if it skips some special cell s,
    # i.e. p + 1 <= s <= p + x - 1.  Collect exactly those jump origins.
    origins = set()
    for s in good + bad:
        origins.update(range(max(0, s - (x - 1)), min(s - 1, n - x) + 1))

    # Keep only jumps that strictly help, ordered by origin (== by skipped end).
    starts = []
    gains = []
    for p in sorted(origins):
        lo, hi = p + 1, p + x - 1
        gain = _count_between(bad, lo, hi) - _count_between(good, lo, hi)
        if gain > 0:
            starts.append(p)
            gains.append(gain)

    # Jumps from p < q are compatible only when q >= p + x: the second jump
    # must start at or after the first one's landing cell p + x, so that cell
    # is always visited.  Merely non-overlapping skipped ranges
    # (q >= p + x - 1) is NOT sufficient.  prev[i] = last compatible index or -1.
    prev = [bisect.bisect_right(starts, p - x) - 1 for p in starts]

    # Layered DP over the jump budget: after k rounds, best[i] is the best gain
    # using at most k jumps among candidates 0..i.  Only one layer is stored.
    best = [0] * len(starts)
    for _ in range(min(h, len(starts))):
        layer = []
        running = 0
        for gain, j in zip(gains, prev):
            running = max(running, gain + (best[j] if j >= 0 else 0))
            layer.append(running)
        if layer == best:  # one more jump no longer helps; later layers match
            break
        best = layer

    return len(good) - len(bad) + (best[-1] if best else 0)


def main():
    """Read one test case from stdin and print the answer.

    Input tokens (whitespace-separated): ``N H X``, then ``G`` followed by G
    good cells, then ``B`` followed by B bad cells.
    """
    import sys

    tokens = iter(sys.stdin.read().split())
    n = int(next(tokens))
    h = int(next(tokens))
    x = int(next(tokens))
    g = int(next(tokens))
    good = [int(next(tokens)) for _ in range(g)]
    b = int(next(tokens))
    bad = [int(next(tokens)) for _ in range(b)]
    print(solve(n, h, x, good, bad))


if __name__ == '__main__':
    main()
0