結果
問題 |
No.2434 RAKUTAN de RAKUTAN
|
ユーザー |
![]() |
提出日時 | 2025-04-16 16:28:59 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 3,359 bytes |
コンパイル時間 | 237 ms |
コンパイル使用メモリ | 82,468 KB |
実行使用メモリ | 252,752 KB |
最終ジャッジ日時 | 2025-04-16 16:30:11 |
合計ジャッジ時間 | 11,518 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 18 WA * 6 |
ソースコード
import bisect def main(): import sys input = sys.stdin.read().split() idx = 0 N = int(input[idx]); idx +=1 H = int(input[idx]); idx +=1 X = int(input[idx]); idx +=1 G = int(input[idx]); idx +=1 g_list = list(map(int, input[idx:idx+G])) idx += G B = int(input[idx]); idx +=1 b_list = list(map(int, input[idx:idx+B])) idx += B # Generate all possible intervals interval_set = set() intervals = [] all_special = g_list + b_list for s in all_special: min_p = max(0, s - (X-1)) max_p = min(s -1, N - X) if min_p > max_p: continue for p in range(min_p, max_p +1): start = p +1 end = p + X -1 if (start, end) not in interval_set: interval_set.add( (start, end) ) intervals.append( (start, end) ) # Sort intervals by end intervals.sort(key=lambda x: x[1]) # Compute value for each interval and filter g_list.sort() b_list.sort() filtered_intervals = [] for start, end in intervals: # Count g left_g = bisect.bisect_left(g_list, start) right_g = bisect.bisect_right(g_list, end) count_g = right_g - left_g # Count b left_b = bisect.bisect_left(b_list, start) right_b = bisect.bisect_right(b_list, end) count_b = right_b - left_b value = count_b - count_g if value >0: filtered_intervals.append( (start, end, value) ) m = len(filtered_intervals) K = min(H, m) if m ==0: print( G - B ) return # Compute prev for each interval sorted_ends = [interval[1] for interval in filtered_intervals] prev = [ -1 ] * m for i in range(m): start_i = filtered_intervals[i][0] low = 0 high = i-1 best_j = -1 while low <= high: mid = (low + high) //2 if sorted_ends[mid] < start_i: best_j = mid low = mid +1 else: high = mid -1 prev[i] = best_j # Initialize DP dp = [ [ -float('inf') for _ in range(K+1) ] for _ in range(m) ] dp[0][0] = 0 if K >=1: if prev[0] == -1: dp[0][1] = filtered_intervals[0][2] else: if prev[0] >=0 and dp[prev[0]][0] != -float('inf'): dp[0][1] = dp[prev[0]][0] + filtered_intervals[0][2] else: dp[0][1] = -float('inf') for i in range(1, m): start_i, end_i, value_i = filtered_intervals[i] for k in range(K+1): # Not take this interval not_take = dp[i-1][k] # Take this interval take = -float('inf') if k >=1: j = prev[i] if j == -1: if k ==1: take = value_i else: if j >=0 and (k-1) >=0 and dp[j][k-1] != -float('inf'): take = dp[j][k-1] + value_i dp[i][k] = max(not_take, take) max_sum = max( dp[m-1][k] for k in range(K+1) ) if max_sum == -float('inf'): max_sum = 0 answer = (G - B) + max_sum print(answer) if __name__ == '__main__': main()