結果

問題 No.297 カードの数式
ユーザー lam6er
提出日時 2025-03-31 17:51:34
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 5,468 bytes
コンパイル時間 140 ms
コンパイル使用メモリ 82,520 KB
実行使用メモリ 57,892 KB
最終ジャッジ日時 2025-03-31 17:52:15
合計ジャッジ時間 3,014 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 1 TLE * 1 -- * 21
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from itertools import permutations
from functools import lru_cache

def main():
    input = sys.stdin.read().split()
    idx = 0
    N = int(input[idx])
    idx += 1
    cards = input[idx:idx + N]
    idx += N

    digits = []
    ops = []
    for c in cards:
        if c in '+-':
            ops.append(c)
        else:
            digits.append(c)
    
    O = len(ops)
    D = len(digits)
    if O == 0:
        print(0)
        return
    
    max_total = -float('inf')
    min_total = float('inf')
    
    # Generate unique operator sequences
    seen = set()
    op_sequences = []
    for p in permutations(ops):
        if p not in seen:
            seen.add(p)
            op_sequences.append(p)
    
    digits_tuple = tuple(digits)
    all_ops_generated = len(op_sequences) > 0
    if not all_ops_generated:
        return
    
    for op_seq in op_sequences:
        # Compute for max_total scenario
        required_opt_max = [True] * (O + 1)
        for i in range(1, O + 1):
            if op_seq[i - 1] == '+':
                required_opt_max[i] = True
            else:
                required_opt_max[i] = False
        
        @lru_cache(maxsize=None)
        def dp_max(group, mask):
            if group == O + 1:
                return 0 if mask == 0 else -float('inf')
            if mask == 0:
                return -float('inf')
            max_val = -float('inf')
            mask_list = [(1 << i) for i in range(D) if (mask & (1 << i))]
            n = len(mask_list)
            current_submask = 0
            for i in range(1, 1 << n):
                current_submask = 0
                cnt = 0
                for j in range(n):
                    if i & (1 << j):
                        current_submask |= mask_list[j]
                        cnt += 1
                if cnt == 0:
                    continue
                sub_digits = []
                for k in range(D):
                    if (current_submask & (1 << k)):
                        sub_digits.append(digits_tuple[k])
                # Determine the term's value
                if required_opt_max[group]:
                    sorted_digits = sorted(sub_digits, reverse=True)
                else:
                    sorted_digits = sorted(sub_digits)
                term = int(''.join(sorted_digits))
                new_mask = mask ^ current_submask
                next_val = dp_max(group + 1, new_mask)
                if next_val == -float('inf'):
                    continue
                if group == 0:
                    current_val = term + next_val
                else:
                    op = op_seq[group - 1]
                    if op == '+':
                        current_val = term + next_val
                    else:
                        current_val = -term + next_val
                if current_val > max_val:
                    max_val = current_val
            return max_val if max_val != -float('inf') else -float('inf')
        
        initial_mask = (1 << D) - 1
        max_curr = dp_max(0, initial_mask)
        if max_curr > max_total:
            max_total = max_curr
        
        # Compute for min_total scenario
        required_opt_min = [False] * (O + 1)
        required_opt_min[0] = False
        for i in range(1, O + 1):
            if op_seq[i - 1] == '+':
                required_opt_min[i] = False
            else:
                required_opt_min[i] = True
        
        @lru_cache(maxsize=None)
        def dp_min(group, mask):
            if group == O + 1:
                return 0 if mask == 0 else float('inf')
            if mask == 0:
                return float('inf')
            min_val = float('inf')
            mask_list = [(1 << i) for i in range(D) if (mask & (1 << i))]
            n = len(mask_list)
            current_submask = 0
            for i in range(1, 1 << n):
                current_submask = 0
                cnt = 0
                for j in range(n):
                    if i & (1 << j):
                        current_submask |= mask_list[j]
                        cnt += 1
                if cnt == 0:
                    continue
                sub_digits = []
                for k in range(D):
                    if (current_submask & (1 << k)):
                        sub_digits.append(digits_tuple[k])
                # Determine the term's value
                if required_opt_min[group]:
                    sorted_digits = sorted(sub_digits, reverse=True)
                else:
                    sorted_digits = sorted(sub_digits)
                term = int(''.join(sorted_digits))
                new_mask = mask ^ current_submask
                next_val = dp_min(group + 1, new_mask)
                if next_val == float('inf'):
                    continue
                if group == 0:
                    current_val = term + next_val
                else:
                    op = op_seq[group - 1]
                    if op == '+':
                        current_val = term + next_val
                    else:
                        current_val = -term + next_val
                if current_val < min_val:
                    min_val = current_val
            return min_val if min_val != float('inf') else float('inf')
        
        min_curr = dp_min(0, initial_mask)
        if min_curr < min_total:
            min_total = min_curr
    
    print(f"{max_total} {min_total}")

if __name__ == '__main__':
    main()
0