結果
| 問題 | No.2423 Merge Stones |
|---|---|
| コンテスト | |
| ユーザー |  gew1fw | 
| 提出日時 | 2025-06-12 19:20:29 | 
| 言語 | PyPy3 (7.3.15) | 
| 結果 | AC |
| 実行時間 | 1,512 ms / 4,000 ms | 
| コード長 | 2,631 bytes | 
| コンパイル時間 | 342 ms | 
| コンパイル使用メモリ | 82,040 KB | 
| 実行使用メモリ | 79,952 KB | 
| 最終ジャッジ日時 | 2025-06-12 19:21:17 | 
| 合計ジャッジ時間 | 32,133 ms | 
| ジャッジサーバーID (参考情報) | judge2 / judge1 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 1 | 
| other | AC * 72 | 
ソースコード
def _max_merged_weight(n, k, weights, colors):
    """Return the heaviest stone obtainable by merging adjacent stones on a circle.

    Two adjacent stones may be merged when their colors differ by at most ``k``.
    The merged stone weighs the sum of both, and its color may be either of the
    two merged colors (this is exactly the rule the DP below encodes).

    Args:
        n: number of stones on the circle (>= 1).
        k: maximum allowed color difference for a merge.
        weights: stone weights, length ``n``.
        colors: 1-based stone colors, length ``n``.

    Returns:
        The maximum total weight of any single stone that can be formed
        (at least ``max(weights)``, since a lone stone counts).
    """
    # Size the color table from the data instead of a hard-coded 50, so larger
    # color values no longer index past the end of the table.
    num_colors = max(colors)
    # near[c]: bitmask of 0-based colors d with |c - d| <= k.
    near = [
        sum(1 << d for d in range(num_colors) if abs(c - d) <= k)
        for c in range(num_colors)
    ]

    spread_cache = {0: 0}

    def spread(mask):
        """Return the union of ``near[c]`` over every color bit c set in ``mask``."""
        res = spread_cache.get(mask)
        if res is None:
            res = 0
            rest = mask
            while rest:
                low = rest & -rest
                res |= near[low.bit_length() - 1]
                rest ^= low
            spread_cache[mask] = res
        return res

    # Unroll the circle: every arc of length <= n is a contiguous slice of the
    # doubled arrays.
    m = 2 * n
    w = weights * 2
    col = [x - 1 for x in colors] * 2  # 0-based colors
    prefix = [0] * (m + 1)
    for i, x in enumerate(w):
        prefix[i + 1] = prefix[i] + x

    # dp[l][r]: bitmask of colors the stone formed from w[l..r] can have
    # (0 when that slice cannot be merged into a single stone).
    dp = [[0] * m for _ in range(m)]
    # reach[l][r] == spread(dp[l][r]), computed once per cell so the split loop
    # below only does bitwise operations. A zero mask has zero reach, so
    # unmergeable halves contribute nothing without an explicit skip.
    reach = [[0] * m for _ in range(m)]
    for i in range(m):
        dp[i][i] = 1 << col[i]
        reach[i][i] = near[col[i]]

    best = max(weights)  # a single, unmerged stone
    for length in range(2, n + 1):
        for l in range(m - length + 1):
            r = l + length - 1
            dp_l = dp[l]
            reach_l = reach[l]
            mask = 0
            for s in range(l, r):
                # Keep a color from one half if some color of the other half
                # is within k of it (the merge is legal and may take that color).
                mask |= (reach_l[s] & dp[s + 1][r]) | (reach[s + 1][r] & dp_l[s])
            dp_l[r] = mask
            reach_l[r] = spread(mask)
            if mask:
                total = prefix[r + 1] - prefix[l]
                if total > best:
                    best = total
    return best


def main():
    """Read N, K, A_1..A_N, C_1..C_N (whitespace-separated) from stdin and print the answer."""
    import sys

    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    k = int(tokens[1])
    weights = list(map(int, tokens[2:2 + n]))
    colors = list(map(int, tokens[2 + n:2 + 2 * n]))
    print(_max_merged_weight(n, k, weights, colors))


if __name__ == "__main__":
    main()
            
            
            
        