結果

問題 No.2423 Merge Stones
ユーザー gew1fw
提出日時 2025-06-12 14:13:16
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,315 ms / 4,000 ms
コード長 2,631 bytes
コンパイル時間 227 ms
コンパイル使用メモリ 82,284 KB
実行使用メモリ 80,032 KB
最終ジャッジ日時 2025-06-12 14:13:53
合計ジャッジ時間 29,442 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 72
権限があれば一括ダウンロードができます

ソースコード

diff #

def _read_input(tokens):
    """Split a whitespace-tokenized instance into its parts.

    Expected token order: N, K, then N weights A_i, then N colors C_i.
    Returns ``(n, k, weights, colors)`` with colors left 1-based.
    """
    n = int(tokens[0])
    k = int(tokens[1])
    weights = list(map(int, tokens[2:2 + n]))
    colors = list(map(int, tokens[2 + n:2 + 2 * n]))
    return n, k, weights, colors


def _max_merged_weight(n, k, weights, colors):
    """Return the largest total weight one stone on the circle can reach.

    Two adjacent pieces may be merged when some color of one is within ``k``
    of some color of the other, and the merged stone may take the color of
    either piece. The circle is handled by doubling the arrays: every
    contiguous arc of at most ``n`` stones is an interval of length <= ``n``
    in the doubled array.

    ``dp[l][r]`` is a bitmask of the 0-based colors that stones ``l..r`` of
    the doubled array can end up with once merged into a single stone
    (0 means they cannot all be merged). Colors must be positive integers;
    the palette size is taken from the input instead of being fixed at 50.
    """
    from itertools import accumulate

    num_colors = max(colors)
    # allowed[c]: bitmask of every color d with |c - d| <= k.
    allowed = [
        sum(1 << d for d in range(num_colors) if abs(c - d) <= k)
        for c in range(num_colors)
    ]

    size = 2 * n
    ext_colors = [c - 1 for c in colors] * 2  # 0-based colors, doubled
    prefix = list(accumulate(weights * 2, initial=0))

    reach_cache = {}

    def reach(mask):
        """OR of ``allowed[c]`` over every color bit ``c`` set in ``mask``."""
        res = reach_cache.get(mask)
        if res is None:
            res = 0
            rest = mask
            while rest:
                low = rest & -rest
                res |= allowed[low.bit_length() - 1]
                rest ^= low
            reach_cache[mask] = res
        return res

    dp = [[0] * size for _ in range(size)]
    # reach_of[l][r] holds reach(dp[l][r]) so each interval is expanded once
    # rather than once per split point; it stays 0 wherever dp is 0.
    reach_of = [[0] * size for _ in range(size)]
    for i, c in enumerate(ext_colors):
        dp[i][i] = 1 << c
        reach_of[i][i] = allowed[c]

    best = max(weights)  # a single, unmerged stone
    for length in range(2, n + 1):
        # Start points at or beyond n are still needed: they serve as the
        # right-hand sub-intervals of arcs that wrap around the circle.
        for l in range(size - length + 1):
            r = l + length - 1
            dp_l = dp[l]
            reach_l = reach_of[l]
            mask = 0
            for m in range(l, r):
                # Colors on one side that are compatible with some color on
                # the other side survive the merge. An unmergeable side
                # (mask 0, reach 0) contributes nothing, so no branch needed.
                mask |= (reach_l[m] & dp[m + 1][r]) | (reach_of[m + 1][r] & dp_l[m])
            if mask:
                dp_l[r] = mask
                reach_l[r] = reach(mask)
                total = prefix[r + 1] - prefix[l]
                if total > best:
                    best = total
    return best


def main():
    """Read the instance from stdin and print the best achievable weight."""
    import sys

    n, k, weights, colors = _read_input(sys.stdin.read().split())
    print(_max_merged_weight(n, k, weights, colors))

if __name__ == "__main__":
    # Run only when executed as a script (not when imported): main() reads
    # the instance from stdin and prints the answer.
    main()
0