結果
| 問題 | No.2423 Merge Stones |
|---|---|
| コンテスト | |
| ユーザー | gew1fw |
| 提出日時 | 2025-06-12 14:13:16 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 1,315 ms / 4,000 ms |
| コード長 | 2,631 bytes |
| コンパイル時間 | 227 ms |
| コンパイル使用メモリ | 82,284 KB |
| 実行使用メモリ | 80,032 KB |
| 最終ジャッジ日時 | 2025-06-12 14:13:53 |
| 合計ジャッジ時間 | 29,442 ms |
| ジャッジサーバーID (参考情報) | judge4 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 72 |
ソースコード
def _max_merged_weight(n, k, weights, colors):
    """Return the largest weight of a single stone obtainable on the circle.

    Stones ``0..n-1`` sit on a circle; stone ``i`` has weight ``weights[i]``
    and 1-based color ``colors[i]``.  ``dp[l][r]`` is a bitmask of the colors
    that the arc ``l..r`` (in the doubled array) can end up with after being
    merged into one stone; 0 means the arc cannot be merged at all.

    Two adjacent arcs are combined when some color of one lies within ``k``
    of some color of the other; the combined arc may then carry any color of
    either half that has such a partner.  The answer is the heaviest arc of at
    most ``n`` stones with a non-empty mask (a lone stone always qualifies).
    """
    size = 2 * n  # the circle is unrolled twice so every arc is contiguous
    ext_weights = weights * 2
    ext_colors = [c - 1 for c in colors] * 2  # 0-based bit positions

    prefix = [0] * (size + 1)
    for i, w in enumerate(ext_weights):
        prefix[i + 1] = prefix[i] + w

    # near[c]: bitmask of every palette color d with |c - d| <= k.  The palette
    # is sized from the input instead of being fixed at 50 colors.
    palette = max(colors)
    near = [0] * palette
    for c in range(palette):
        for d in range(palette):
            if abs(c - d) <= k:
                near[c] |= 1 << d

    spread_cache = {}

    def spread(mask):
        """Return the union of ``near[c]`` over every color bit ``c`` in mask."""
        res = spread_cache.get(mask)
        if res is None:
            res = 0
            rest = mask
            while rest:
                low = rest & -rest
                res |= near[low.bit_length() - 1]
                rest ^= low
            spread_cache[mask] = res
        return res

    dp = [[0] * size for _ in range(size)]
    # reach[l][r] caches spread(dp[l][r]) so it is computed once per cell,
    # not once per split point that reads the cell.
    reach = [[0] * size for _ in range(size)]
    for i, c in enumerate(ext_colors):
        dp[i][i] = 1 << c
        reach[i][i] = near[c]

    best = max(weights)
    for length in range(2, n + 1):
        for l in range(size - length + 1):
            r = l + length - 1
            dp_l = dp[l]
            reach_l = reach[l]
            mask = 0
            for mid in range(l, r):
                left = dp_l[mid]
                right = dp[mid + 1][r]
                if left and right:
                    mask |= (reach_l[mid] & right) | (reach[mid + 1][r] & left)
            if mask:
                dp_l[r] = mask
                reach_l[r] = spread(mask)
                total = prefix[r + 1] - prefix[l]
                if total > best:
                    best = total
    return best


def main():
    """Read ``N K``, then ``A_1..A_N``, then ``C_1..C_N`` from stdin; print the answer."""
    import sys

    tokens = sys.stdin.read().split()
    n, k = int(tokens[0]), int(tokens[1])
    weights = list(map(int, tokens[2:2 + n]))
    colors = list(map(int, tokens[2 + n:2 + 2 * n]))
    print(_max_merged_weight(n, k, weights, colors))
# Solve only when run as a script; importing the module has no side effects.
if __name__ == "__main__": main()
gew1fw