結果
問題 |
No.2423 Merge Stones
|
ユーザー |
（表示なし） |
提出日時 | 2025-06-12 19:20:29 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 1,512 ms / 4,000 ms |
コード長 | 2,631 bytes |
コンパイル時間 | 342 ms |
コンパイル使用メモリ | 82,040 KB |
実行使用メモリ | 79,952 KB |
最終ジャッジ日時 | 2025-06-12 19:21:17 |
合計ジャッジ時間 | 32,133 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | AC * 72 |
ソースコード
def _color_masks(colors, k):
    """Map color values to bit indices and build the merge-compatibility table.

    Two stones may be merged when their colors differ by at most ``k``.  Only
    the differences between colors matter, so the distinct color values are
    compressed to consecutive bit indices.  This keeps the bitmasks as narrow
    as the number of distinct colors and accepts any integer color value
    (the previous fixed 50-entry table crashed on colors above 50).

    Args:
        colors: Color value of every stone.
        k: Maximum allowed color difference for a merge.

    Returns:
        A pair ``(bit_of, compatible)``: ``bit_of`` maps each color value to
        its bit index, and ``compatible[i]`` is the bitmask of every bit index
        whose color lies within ``k`` of the color with bit index ``i``.
    """
    values = sorted(set(colors))
    bit_of = {value: index for index, value in enumerate(values)}
    compatible = [0] * len(values)
    for i, vi in enumerate(values):
        for j, vj in enumerate(values):
            if abs(vi - vj) <= k:
                compatible[i] |= 1 << j
    return bit_of, compatible


def _expand(mask, compatible):
    """Return the union of ``compatible`` rows over every bit set in ``mask``.

    For a stone whose possible colors are ``mask``, this is the set of colors
    it is allowed to merge with.
    """
    result = 0
    while mask:
        lowest = mask & -mask
        result |= compatible[lowest.bit_length() - 1]
        mask ^= lowest
    return result


def _max_merged_weight(n, k, weights, colors):
    """Return the largest weight of a single stone obtainable by merging.

    The ``n`` stones form a circle.  Two adjacent stones whose colors differ
    by at most ``k`` can be merged; the merged stone's weight is the sum and
    its color is that of either part.  A range of stones can become one stone
    iff its set of achievable colors is non-empty, so the answer is the
    largest weight sum over all such circular ranges of at most ``n`` stones.

    Args:
        n: Number of stones.
        k: Maximum color difference for a merge.
        weights: Weight of each stone (length ``n``).
        colors: Color of each stone (length ``n``).

    Returns:
        The maximum achievable weight, or 0 when there are no stones.
    """
    if n == 0:
        return 0
    bit_of, compatible = _color_masks(colors, k)

    size = 2 * n  # circle unrolled twice so every arc is a contiguous range
    ext_weights = list(weights) * 2
    prefix = [0] * (size + 1)
    for i, w in enumerate(ext_weights):
        prefix[i + 1] = prefix[i] + w

    # reach[l][r]: bitmask of colors the stone merged from stones l..r can
    # take (0 when the range cannot be merged into a single stone).
    # near[l][r]: colors such a stone may merge with (cached _expand result),
    # computed once per range instead of once per split.
    reach = [[0] * size for _ in range(size)]
    near = [[0] * size for _ in range(size)]
    for i in range(size):
        bit = bit_of[colors[i % n]]
        reach[i][i] = 1 << bit
        near[i][i] = compatible[bit]

    best = max(weights)  # a single, never-merged stone is always available
    for length in range(2, n + 1):
        for left in range(size - length + 1):
            right = left + length - 1
            reach_left = reach[left]
            near_left = near[left]
            merged = 0
            for split in range(left, right):
                a = reach_left[split]
                b = reach[split + 1][right]
                if a and b:
                    # Keep b's colors that a can merge with, and vice versa;
                    # the merged stone may take the color of either side.
                    merged |= (near_left[split] & b) | (near[split + 1][right] & a)
            if merged:
                reach[left][right] = merged
                near[left][right] = _expand(merged, compatible)
                total = prefix[right + 1] - prefix[left]
                if total > best:
                    best = total
    return best


def main():
    """Read ``N K``, then N weights and N colors from stdin; print the answer."""
    import sys

    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    k = int(tokens[1])
    weights = [int(t) for t in tokens[2:2 + n]]
    colors = [int(t) for t in tokens[2 + n:2 + 2 * n]]
    print(_max_merged_weight(n, k, weights, colors))


if __name__ == "__main__":
    main()