結果

問題 No.2026 Yet Another Knapsack Problem
ユーザー lam6er
提出日時 2025-03-20 20:30:46
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 4,988 bytes
コンパイル時間 175 ms
コンパイル使用メモリ 82,656 KB
実行使用メモリ 282,340 KB
最終ジャッジ日時 2025-03-20 20:32:30
合計ジャッジ時間 24,124 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 30 TLE * 1 -- * 11
権限があれば一括ダウンロードができます

ソースコード

diff #

import bisect

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    idx = 0
    N = int(data[idx])
    idx +=1
    items = []
    v1 = 0
    for i in range(N):
        c = int(data[idx])
        v = int(data[idx+1])
        idx +=2
        if i ==0:
            v1 = v
            c1 = c
        else:
            items.append( (i+1, c, v) )
    
    # Step 2: Process other items with DP using binary optimization
    # other_dp is a dictionary of (t, w) to max_value
    from collections import defaultdict
    other_dp = defaultdict(lambda: -float('inf'))
    other_dp[(0, 0)] = 0

    for i, (weight, cnt, val) in enumerate(items):
        # Binary decomposition of cnt
        k = 1
        temp = cnt
        parts = []
        while temp >0:
            if temp >=k:
                parts.append(k)
                temp -=k
                k *=2
            else:
                parts.append(temp)
                temp =0
        # Process each part
        for x in parts:
            new_states = defaultdict(lambda: -float('inf'))
            # Iterate current states
            for (t_prev, w_prev), v_prev in other_dp.items():
                if v_prev == -float('inf'):
                    continue
                new_t = t_prev + x
                new_w = w_prev + x * weight
                if new_t > N or new_w > N:
                    continue
                new_v = v_prev + x * val
                if new_v > new_states[(new_t, new_w)]:
                    new_states[(new_t, new_w)] = new_v
            # Merge new states into other_dp
            for (t, w), v in new_states.items():
                if v > other_dp[(t, w)]:
                    other_dp[(t, w)] = v

    # Preprocess other_dp into a structure for each t: list of (w, max_v), sorted by w
    # And for each t, compute prefix max
    preprocessed = defaultdict(list)
    for (t, w) in other_dp:
        v = other_dp[(t, w)]
        if v == -float('inf'):
            continue
        preprocessed[t].append( (w, v) )
    
    # Sort each list by w, and compute prefix max
    max_for_t = {}
    for t in preprocessed:
        lst = sorted(preprocessed[t], key=lambda x: x[0])
        # Compute prefix max
        prefix_max = []
        current_max = -float('inf')
        new_lst = []
        # Merge duplicates in w (keep the best v)
        prev_w = -1
        best_v = -float('inf')
        for w, v in lst:
            if w != prev_w:
                if prev_w != -1:
                    new_lst.append( (prev_w, best_v) )
                prev_w = w
                best_v = v
            else:
                best_v = max(best_v, v)
        new_lst.append( (prev_w, best_v) )
        # Now, compute prefix max
        current_max = -float('inf')
        filtered = []
        for w, v in new_lst:
            if v > current_max:
                current_max = v
            filtered.append( (w, current_max) )
        # Update the max_for_t
        ws = [x[0] for x in filtered]
        vs = [x[1] for x in filtered]
        max_for_t[t] = (ws, vs)
    
    # Step 4: compute answers for each k from 1 to N
    for k in range(1, N+1):
        max_val = -float('inf')
        t_max = min(k, N -k)
        # We need t <= t_max and t >=0
        for t in range(0, t_max +1):
            k1 = k -t
            if k1 <0:
                continue
            other_max_w = N -k1
            # For other items' t: find the maximum v where w <= other_max_w and w >= 2*t
            if t not in max_for_t:
                if t ==0 and other_max_w >=0:
                    # 0 other items, the rest are k items of type1
                    candidate = 0 + k * v1
                    max_val = max(max_val, candidate)
                continue
            ws, vs = max_for_t[t]
            if not ws:
                continue
            # w must >= 2*t and <= other_max_w
            min_w = 2*t
            max_w = other_max_w
            if max_w < min_w:
                continue
            # find the largest w in ws <= max_w and >= min_w
            # find the first index >= min_w
            left = bisect.bisect_left(ws, min_w)
            # find all elements from left onwards where w <= max_w
            right = bisect.bisect_right(ws, max_w) -1
            if right <0:
                continue
            # the maximum in [left, right] is vs[right], since vs is increasing
            if left > right:
                continue
            if left > len(vs) -1:
                continue
            current_max_v = vs[right]
            total_v = current_max_v + k1 * v1
            if total_v > max_val:
                max_val = total_v
        # If no other items are selected, check if we can select all k items from type1
        # Their weight would be k *1 =k <=N
        if k *1 <=N:
            candidate = k * v1
            if candidate > max_val:
                max_val = candidate
        print(max_val)

if __name__ == "__main__":
    main()
0