結果

問題 No.3097 Azuki Kurai
ユーザー RiRinbaru
提出日時 2025-03-26 16:11:37
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,183 bytes
コンパイル時間 495 ms
コンパイル使用メモリ 81,920 KB
実行使用メモリ 78,592 KB
最終ジャッジ日時 2025-03-26 16:11:46
合計ジャッジ時間 8,348 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other WA * 32
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

N_MAX = 10
M_MAX = 2000
bit_N_MAX = 1 << N_MAX
sup = 1 << 60

bit = [1 << i for i in range(25)]

def chmin(a, b):
    return min(a, b)

def solve1_ac(N, M, K, A, B):
    dp = [[0] * bit_N_MAX for _ in range(2)]
    n = [[0] * bit_N_MAX for _ in range(N)]
    q = [[[] for _ in range(bit_N_MAX)] for _ in range(N)]
    
    for i in range(N):
        for k in range(1 << N):
            n[i][k] = 0
            valid_states = []
            for kk in [k ^ bit[i]] if (k & bit[i]) else [k]:
                l = kk
                while l < (1 << N):
                    if (l & bit[i]) == 0:
                        if all(
                            (k & bit[(j - 1 + N) % N]) == 0 and (k & bit[j]) == 0 and
                            (k & bit[(j + 1) % N]) == 0 and (l & bit[j]) != 0
                            for j in range(N)
                        ):
                            valid_states.append((l, sum(
                                1 for j in range(N)
                                if (k & bit[j]) and (
                                    (i != (j - 1 + N) % N and (l & bit[(j - 1 + N) % N]) == 0) or
                                    (i != (j + 1) % N and (l & bit[(j + 1) % N]) == 0)
                                )
                            )))
                    l = (l + 1) | kk
            q[i][k] = valid_states
            n[i][k] = len(valid_states)
    
    for k in range(1 << N):
        dp[0][k] = sum(A[j] for j in range(N) if (k & bit[j]) == 0)
    
    cur, prev = 1, 0
    ans = [0] * (M + 1)
    
    for i in range(1, M + 1):
        dp[cur] = [sup] * bit_N_MAX
        
        for k in range(1 << N):
            for l, cost in q[B[i]][k]:
                dp[cur][l] = chmin(dp[cur][l], dp[prev][k] + K * cost)
        
        ans[i] = dp[cur][0]
        cur, prev = prev, cur
    
    return ans

def main():
    N, M, K = map(int, sys.stdin.readline().split())
    A = list(map(int, sys.stdin.readline().split()))
    B = list(map(lambda x: int(x) - 1, sys.stdin.readline().split()))
    ans = solve1_ac(N, M, K, A, [0] + B)
    
    for i in range(1, M + 1):
        print(ans[i])

if __name__ == "__main__":
    main()
0