結果
| 問題 | No.2423 Merge Stones |
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-04-15 22:03:38 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 3,030 bytes |
| 記録 | |
| コンパイル時間 | 163 ms |
| コンパイル使用メモリ | 81,668 KB |
| 実行使用メモリ | 61,332 KB |
| 最終ジャッジ日時 | 2025-04-15 22:04:59 |
| 合計ジャッジ時間 | 6,212 ms |
| ジャッジサーバーID (参考情報) | judge3 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 10 TLE * 1 -- * 61 |
ソースコード
def _dilate(mask, k, full):
    """Return the bitmask of colors within distance ``k`` of some color in ``mask``.

    Colors are bit positions. The mask is first spread upward over every
    shift in [0, 2k] by doubling, which takes O(log k) big-int operations
    instead of one per shift. It is then shifted down by ``k``, so each color
    covers [-k, k]; negative positions fall off only at that final step.
    The result is clipped to ``full`` to keep the integers small.
    """
    spread = mask
    width = 0  # `spread` currently holds every upward shift in [0, width]
    while width < 2 * k:
        # A shift of at most width + 1 keeps the covered range contiguous.
        step = min(width + 1, 2 * k - width)
        spread |= spread << step
        width += step
    return (spread >> k) & full


def _solve(n, k, a, c):
    """Return the largest weight total of a circular arc that can end as one stone.

    Two adjacent pieces may combine when some color of one is within ``k`` of
    some color of the other. The combined piece keeps any such color.

    Args:
        n: number of stones on the circle.
        k: maximum allowed color difference for a merge.
        a: weights of the stones, length ``n``.
        c: colors of the stones, length ``n``. Each color is used as a bit
            position, so this assumes small non-negative ints (presumably
            1..50 per the problem constraints; verify).

    Returns:
        The maximum over all arcs of length 1..n whose stones can become a
        single piece of the sum of their weights.
    """
    full = (1 << (max(c) + 1)) - 1
    # reach[s][L]: bitmask of colors a single piece covering the L stones that
    #   start at index s (wrapping around) can have; 0 means "cannot merge".
    # near[s][L]: colors within distance k of some color in reach[s][L],
    #   cached once per arc so each split point costs O(1) bit operations.
    reach = [[0] * (n + 1) for _ in range(n)]
    near = [[0] * (n + 1) for _ in range(n)]
    for s in range(n):
        reach[s][1] = 1 << c[s]
        near[s][1] = _dilate(reach[s][1], k, full)

    # Prefix sums over the doubled weights give any arc's total in O(1).
    prefix = [0]
    for w in a + a:
        prefix.append(prefix[-1] + w)

    best = max(a)  # every single stone is trivially achievable
    for length in range(2, n + 1):
        for s in range(n):
            reach_s = reach[s]
            near_s = near[s]
            colors = 0
            for t in range(1, length):
                left = reach_s[t]
                if not left:
                    continue
                r = (s + t) % n  # start of the right part, wrapping around
                right = reach[r][length - t]
                if not right:
                    continue
                # A left color survives if some right color is within k, and
                # vice versa.
                colors |= (left & near[r][length - t]) | (right & near_s[t])
            if colors:
                reach_s[length] = colors
                near_s[length] = _dilate(colors, k, full)
                total = prefix[s + length] - prefix[s]
                if total > best:
                    best = total
    return best


def main():
    """Read N, K, then N weights A and N colors C from stdin; print the answer.

    The printed value is the largest weight total of a circular arc whose
    stones can all be merged into one (see ``_solve``).
    """
    import sys

    tokens = sys.stdin.read().split()
    n, k = int(tokens[0]), int(tokens[1])
    a = [int(x) for x in tokens[2:2 + n]]
    c = [int(x) for x in tokens[2 + n:2 + 2 * n]]
    print(_solve(n, k, a, c))


if __name__ == '__main__':
    main()
lam6er