結果

問題 No.2423 Merge Stones
ユーザー gew1fw
提出日時 2025-06-12 16:09:50
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,585 bytes
コンパイル時間 215 ms
コンパイル使用メモリ 82,652 KB
実行使用メモリ 66,048 KB
最終ジャッジ日時 2025-06-12 16:10:00
合計ジャッジ時間 7,529 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 10 TLE * 1 -- * 61
権限があれば一括ダウンロードができます

ソースコード

diff #

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    
    N = int(data[0])
    K = int(data[1])
    A = list(map(int, data[2:2+N]))
    C = list(map(int, data[2+N:2+2*N]))
    
    # Duplicate the arrays to handle circular intervals
    A_duplicated = A + A
    C_duplicated = C + C
    
    # Precompute the allowed color masks for each color
    max_color = 50
    pre_masks = [0] * (max_color + 2)  # 0-based, 1-based colors
    for c in range(1, max_color + 1):
        for dc in range(-K, K + 1):
            nc = c + dc
            if 1 <= nc <= max_color:
                pre_masks[c] |= (1 << (nc - 1))  # since colors are 1-based
    
    # Precompute sum_duplicated[i][j]
    sum_duplicated = [[0] * (2 * N) for _ in range(2 * N)]
    for i in range(2 * N):
        for j in range(i, 2 * N):
            if j - i + 1 > N:
                sum_duplicated[i][j] = -1  # invalid
            else:
                sum_duplicated[i][j] = sum(A_duplicated[i:j+1])
    
    # Initialize DP table
    dp = [[0] * (2 * N) for _ in range(2 * N)]
    for i in range(2 * N):
        color = C_duplicated[i]
        dp[i][i] = 1 << (color - 1)  # mask for color
    
    for l in range(2, N + 1):
        for i in range(2 * N):
            j = i + l - 1
            if j >= 2 * N:
                continue
            if sum_duplicated[i][j] == -1:
                continue  # invalid interval
            
            dp[i][j] = 0  # reset for this interval
            for k in range(i, j):
                mask1 = dp[i][k]
                mask2 = dp[k+1][j]
                new_colors = 0
                for c in range(50):
                    if (mask1 & (1 << c)) != 0:
                        allowed = pre_masks[c + 1]  # c+1 is the actual color
                        if (mask2 & allowed) != 0:
                            new_colors |= (1 << c)
                    if (mask2 & (1 << c)) != 0:
                        allowed = pre_masks[c + 1]
                        if (mask1 & allowed) != 0:
                            new_colors |= (1 << c)
                dp[i][j] |= new_colors
    
    max_sum = 0
    for l in range(1, N + 1):
        for i in range(2 * N):
            j = i + l - 1
            if j >= 2 * N:
                continue
            if sum_duplicated[i][j] == -1:
                continue  # invalid
            if dp[i][j] != 0:
                current_sum = sum_duplicated[i][j]
                if current_sum > max_sum:
                    max_sum = current_sum
    print(max_sum)

if __name__ == "__main__":
    main()
0