結果

問題 No.2423 Merge Stones
ユーザー lam6er
提出日時 2025-04-15 22:02:14
言語 PyPy3 (7.3.15)
結果 TLE
実行時間 -
コード長 3,030 bytes
コンパイル時間 459 ms
コンパイル使用メモリ 81,756 KB
実行使用メモリ 89,888 KB
最終ジャッジ日時 2025-04-15 22:03:28
合計ジャッジ時間 6,598 ms
ジャッジサーバーID (参考情報) judge3 / judge1
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 10 TLE * 1 -- * 61
権限があれば一括ダウンロードができます

ソースコード

diff #

def _dilate(mask, radius):
    """Return the mask of colours within ``radius`` of some colour in ``mask``.

    Bit ``c`` of a mask stands for colour ``c``.  Bit ``x`` (``x >= 0``) of
    the result is set iff ``|x - c| <= radius`` for some set bit ``c`` of
    ``mask``.  A negative radius returns 0 (no colour is compatible).
    """
    if radius < 0:
        return 0
    # Smear each set bit upward over a window of width 2*radius using
    # doubling shifts (O(log radius) steps), then shift back down by radius
    # so the window is centred on the original bit.  Only left shifts are
    # used while smearing, so no coverage is lost near colour 0.
    span = 2 * radius
    width = 0
    while width < span:
        # The mask already covers [c, c + width]; a shift of at most
        # width + 1 keeps the covered range contiguous.
        step = min(width + 1, span - width)
        mask |= mask << step
        width += step
    return mask >> radius


def _max_merged_weight(k, weights, colors):
    """Return the largest weight of a single stone obtainable by merging.

    Stones sit on a circle (``weights[i]``, ``colors[i]``).  Two adjacent
    stones whose colours differ by at most ``k`` may merge into one stone.
    Its weight is the sum of the two, and its colour is either of theirs.
    The result is the maximum weight over every arc that can be merged into
    one stone.  It is never below 0, and it is 0 for empty input.

    Each interval split is tested with O(1) mask operations using
    pre-dilated masks, so the DP does O(N^3) cheap steps instead of
    enumerating colours per split.
    """
    n = len(weights)
    if n == 0:
        return 0
    # Unroll the circle: every arc of length <= n is a contiguous interval
    # of the first 2n-1 positions of the doubled sequence.
    size = 2 * n - 1
    w = weights + weights[:n - 1]
    col = colors + colors[:n - 1]
    prefix = [0] * (size + 1)
    for pos, x in enumerate(w):
        prefix[pos + 1] = prefix[pos] + x

    # Dilated masks are only ever ANDed with achievable-colour masks, which
    # are subsets of the colours present.  Clipping to those colours changes
    # no result and keeps every mask a small int (fast on PyPy).
    universe = 0
    for c in colors:
        universe |= 1 << c

    # can[i][j]  : colours the single stone built from [i..j] may take
    #              (0 if that interval cannot be merged into one stone).
    # grow[i][j] : colours within k of some colour in can[i][j].
    # can_end[j][i] / grow_end[j][i] mirror the same values indexed by the
    # end point, so the split loop reads both halves from two flat rows.
    can = [[0] * size for _ in range(size)]
    grow = [[0] * size for _ in range(size)]
    can_end = [[0] * size for _ in range(size)]
    grow_end = [[0] * size for _ in range(size)]

    best = 0
    for pos in range(size):
        single = 1 << col[pos]
        dilated = _dilate(single, k) & universe
        can[pos][pos] = can_end[pos][pos] = single
        grow[pos][pos] = grow_end[pos][pos] = dilated
        if w[pos] > best:
            best = w[pos]

    for length in range(2, n + 1):
        for i in range(size - length + 1):
            j = i + length - 1
            left_can, left_grow = can[i], grow[i]
            right_can, right_grow = can_end[j], grow_end[j]
            merged = 0
            # Split into [i..s-1] + [s..j].  The merged stone may keep a
            # colour from either side if the other side has a compatible one.
            # An unmergeable half has mask 0, so it contributes nothing.
            for s in range(i + 1, j + 1):
                merged |= (left_can[s - 1] & right_grow[s]) | (
                    right_can[s] & left_grow[s - 1]
                )
            if merged:
                dilated = _dilate(merged, k) & universe
                can[i][j] = can_end[j][i] = merged
                grow[i][j] = grow_end[j][i] = dilated
                total = prefix[j + 1] - prefix[i]
                if total > best:
                    best = total
    return best


def main():
    """Solve yukicoder No.2423 "Merge Stones" from stdin to stdout.

    Input (whitespace separated): ``N K``, then ``N`` weights, then ``N``
    colours.  Prints the largest stone weight obtainable by merging adjacent
    stones on the circle (see ``_max_merged_weight``).
    """
    import sys

    data = sys.stdin.read().split()
    n = int(data[0])
    k = int(data[1])
    weights = [int(x) for x in data[2:2 + n]]
    colors = [int(x) for x in data[2 + n:2 + 2 * n]]
    print(_max_merged_weight(k, weights, colors))

if __name__ == "__main__":
    # Run the solver only when executed as a script, not on import.
    main()
0