結果

問題 No.2617 容量3のナップザック
ユーザー gew1fw
提出日時 2025-06-12 21:36:23
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 1,783 bytes
コンパイル時間 290 ms
コンパイル使用メモリ 81,972 KB
実行使用メモリ 246,928 KB
最終ジャッジ日時 2025-06-12 21:38:53
合計ジャッジ時間 10,657 ms
ジャッジサーバーID
(参考情報)
judge5 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 12 WA * 28
権限があれば一括ダウンロードができます

ソースコード

diff #

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    N = int(data[0])
    K = int(data[1])
    seed = int(data[2])
    a = int(data[3])
    b = int(data[4])
    m = int(data[5])
    
    # Generate f sequence
    f = [0] * (2 * N + 2)  # f[1] to f[2N]
    f[1] = seed
    for i in range(1, 2 * N):
        f[i+1] = (a * f[i] + b) % m
    
    # Compute w and v for each item
    w1 = []
    w2 = []
    w3 = []
    for i in range(1, N+1):
        fi = f[i]
        wi = fi % 3 + 1
        fNi = f[N + i]
        vi = wi * fNi
        if wi == 1:
            w1.append(vi)
        elif wi == 2:
            w2.append(vi)
        else:
            w3.append(vi)
    
    # Sort each group in descending order
    w1_sorted = sorted(w1, reverse=True)
    w2_sorted = sorted(w2, reverse=True)
    w3_sorted = sorted(w3, reverse=True)
    
    # Compute prefix sum for w1
    prefix = [0]
    current = 0
    for val in w1_sorted:
        current += val
        prefix.append(current)
    
    i1 = 0
    i2 = 0
    i3 = 0
    total = 0
    len_w1 = len(w1_sorted)
    len_w2 = len(w2_sorted)
    len_w3 = len(w3_sorted)
    
    for _ in range(K):
        a_val = w3_sorted[i3] if i3 < len_w3 else 0
        b_val = 0
        if i2 < len_w2 and i1 < len_w1:
            b_val = w2_sorted[i2] + w1_sorted[i1]
        c_val = 0
        if i1 + 3 <= len_w1:
            c_val = prefix[i1 + 3] - prefix[i1]
        
        max_val = max(a_val, b_val, c_val)
        if max_val == 0:
            break
        
        total += max_val
        
        if max_val == a_val:
            i3 += 1
        elif max_val == b_val:
            i2 += 1
            i1 += 1
        else:
            i1 += 3
    
    print(total)

if __name__ == '__main__':
    main()
0