結果

問題 No.2423 Merge Stones
ユーザー lam6er
提出日時 2025-04-16 15:43:54
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,088 bytes
コンパイル時間 279 ms
コンパイル使用メモリ 81,852 KB
実行使用メモリ 87,128 KB
最終ジャッジ日時 2025-04-16 15:46:54
合計ジャッジ時間 7,079 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 10 TLE * 1 -- * 61
権限があれば一括ダウンロードができます

ソースコード

diff #

def main():
    import sys
    input = sys.stdin.read().split()
    idx = 0
    N = int(input[idx])
    idx += 1
    K = int(input[idx])
    idx += 1
    A = list(map(int, input[idx:idx+N]))
    idx += N
    C = list(map(int, input[idx:idx+N]))
    
    # Duplicate the array to handle circularity
    A += A
    C += C
    
    # Precompute allowed masks for each color (1-based)
    max_color = 50
    allowed_masks = [0] * (max_color + 2)  # 0..50
    for c in range(1, max_color + 1):
        lower = max(1, c - K)
        upper = min(max_color, c + K)
        mask = 0
        for d in range(lower, upper + 1):
            mask |= 1 << d
        allowed_masks[c] = mask
    
    # Precompute prefix sums
    prefix = [0] * (2 * N + 1)
    for i in range(2 * N):
        prefix[i + 1] = prefix[i] + A[i]
    
    # Initialize DP table
    dp = [[0] * (2 * N) for _ in range(2 * N)]
    for i in range(2 * N):
        c = C[i]
        dp[i][i] = 1 << c
    
    # Fill DP table
    for l in range(2, N + 1):
        for i in range(2 * N - l + 1):
            j = i + l - 1
            current_mask = 0
            for k in range(i, j):
                left = dp[i][k]
                right = dp[k + 1][j]
                if left == 0 or right == 0:
                    continue
                # Iterate over all possible colors in left
                c1 = 1
                while c1 <= max_color:
                    if (left & (1 << c1)) == 0:
                        c1 += 1
                        continue
                    allowed = allowed_masks[c1]
                    overlap = right & allowed
                    if overlap != 0:
                        current_mask |= (1 << c1) | overlap
                    c1 += 1
            dp[i][j] = current_mask
    
    # Find the maximum sum
    max_sum = 0
    for i in range(2 * N):
        for j in range(i, min(i + N, 2 * N)):
            if dp[i][j] != 0:
                s = prefix[j + 1] - prefix[i]
                if s > max_sum:
                    max_sum = s
    print(max_sum)

if __name__ == '__main__':
    main()
0