結果

問題 No.2869 yuusaan's Knapsacks
ユーザー lam6er
提出日時 2025-03-26 15:54:17
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 677 ms / 4,500 ms
コード長 3,822 bytes
コンパイル時間 346 ms
コンパイル使用メモリ 82,376 KB
実行使用メモリ 90,548 KB
最終ジャッジ日時 2025-03-26 15:55:14
合計ジャッジ時間 8,887 ms
ジャッジサーバーID (参考情報) judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 27
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from itertools import combinations

def _pack(capacities, chosen):
    """Try to place every item of ``chosen`` into knapsacks of ``capacities``.

    ``chosen`` holds (value, weight, label) tuples.  Items are placed heaviest
    first (ties broken by label).  For each item, knapsacks are tried in
    order of decreasing remaining capacity, then by index, with backtracking.

    Returns one list of item labels per knapsack (in placement order), or
    None when the items cannot all be packed.
    """
    order = sorted(chosen, key=lambda it: (-it[1], it[2]))
    # suffix[k] = total weight of order[k:], used to prune hopeless branches.
    suffix = [0] * (len(order) + 1)
    for k in range(len(order) - 1, -1, -1):
        suffix[k] = suffix[k + 1] + order[k][1]

    caps = list(capacities)
    bins = [[] for _ in capacities]
    # (depth, sorted remaining capacities) states already proven infeasible.
    # Whether the remaining items fit depends only on this pair, so a failure
    # can be reused no matter which knapsack holds which capacity.
    dead = set()

    def place(k):
        if k == len(order):
            return True
        if suffix[k] > sum(caps):
            return False
        state = (k, tuple(sorted(caps)))
        if state in dead:
            return False
        weight, label = order[k][1], order[k][2]
        tried = set()
        for b in sorted(range(len(caps)), key=lambda x: (-caps[x], x)):
            room = caps[b]
            if room < weight:
                break  # capacities are descending: nothing later fits either
            if room in tried:
                # Same remaining capacity as a knapsack that already failed:
                # the resulting subproblem is identical, so it fails too.
                continue
            tried.add(room)
            caps[b] -= weight
            bins[b].append(label)
            if place(k + 1):
                return True
            # Undo the placement in place instead of copying state per node.
            bins[b].pop()
            caps[b] += weight
        dead.add(state)
        return False

    return bins if place(0) else None


def _solve(capacities, items):
    """Return ``(best_value, bins)`` for the given knapsacks and items.

    ``items`` are (value, weight, label) tuples.  Subsets are tried in order of
    decreasing total value (ties: smaller bitmask first).  The first subset
    that can be packed wins.  When no subset with positive value fits, the
    result is ``(0, [[] ...])`` (the empty selection).
    """
    m = len(items)
    total_cap = sum(capacities)
    largest = max(capacities, default=0)

    totals = []
    for mask in range(1 << m):
        value = weight = 0
        for j in range(m):
            if mask >> j & 1:
                value += items[j][0]
                weight += items[j][1]
        # Masks are unique, so tuple ordering never reaches the weight field.
        totals.append((-value, mask, weight))
    totals.sort()

    for neg_value, mask, weight in totals:
        if neg_value >= 0:
            # Every remaining subset is worth <= 0; the empty selection is
            # at least as good, so stop here.
            break
        if weight > total_cap:
            continue
        chosen = [items[j] for j in range(m) if mask >> j & 1]
        if any(it[1] > largest for it in chosen):
            continue
        bins = _pack(capacities, chosen)
        if bins is not None:
            return -neg_value, bins
    return 0, [[] for _ in capacities]


def main():
    """Read the instance from stdin and print the best packing to stdout.

    Input, as parsed here (whitespace separated): ``N M``, then N knapsack
    capacities, then M pairs ``v w`` (value, weight).  Items are labelled
    1..M in input order.

    Output: the best total value, then one line per knapsack containing the
    number of items in it followed by their labels in ascending order.
    """
    tokens = sys.stdin.read().split()
    n, m = int(tokens[0]), int(tokens[1])
    capacities = [int(t) for t in tokens[2:2 + n]]
    rest = tokens[2 + n:]
    items = [(int(rest[2 * j]), int(rest[2 * j + 1]), j + 1) for j in range(m)]

    best, bins = _solve(capacities, items)
    lines = [str(best)]
    # Same format for empty and non-empty knapsacks (no trailing space).
    lines.extend(" ".join(map(str, [len(b), *sorted(b)])) for b in bins)
    print("\n".join(lines))

# Entry point: solve only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
0