結果

問題 No.174 カードゲーム(Hard)
ユーザー gew1fw
提出日時 2025-06-12 14:18:59
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,451 bytes
コンパイル時間 188 ms
コンパイル使用メモリ 82,304 KB
実行使用メモリ 64,444 KB
最終ジャッジ日時 2025-06-12 14:19:06
合計ジャッジ時間 3,783 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 2 TLE * 1 -- * 9
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from functools import lru_cache

def main():
    input = sys.stdin.read().split()
    idx = 0
    N = int(input[idx]); idx +=1
    PA = float(input[idx]); idx +=1
    PB = float(input[idx]); idx +=1
    A = list(map(int, input[idx:idx+N]))
    idx +=N
    B = list(map(int, input[idx:idx+N]))
    
    A_sorted = sorted(A)
    B_sorted = sorted(B)
    
    a_index = {a: i for i, a in enumerate(A_sorted)}
    b_index = {b: i for i, b in enumerate(B_sorted)}
    
    a_all = (1 << N) - 1
    b_all = (1 << N) - 1
    
    @lru_cache(maxsize=None)
    def get_min_indices(mask, is_A):
        if is_A:
            arr = A_sorted
        else:
            arr = B_sorted
        min_val = None
        min_idx = -1
        count = 0
        indices = []
        for i in range(N):
            if (mask >> i) & 1:
                val = arr[i]
                if min_val is None or val < min_val:
                    min_val = val
                    min_idx = i
                indices.append(i)
        return min_idx, indices
    
    @lru_cache(maxsize=None)
    def dp(amask, bmask):
        if amask == 0 and bmask == 0:
            return 0.0
        a_min, a_indices = get_min_indices(amask, True)
        b_min, b_indices = get_min_indices(bmask, False)
        m = len(a_indices)
        n = len(b_indices)
        if m == 0 or n == 0:
            return 0.0
        
        total = 0.0
        
        for a_bit in a_indices:
            a_val = A_sorted[a_bit]
            if m == 1:
                prob_a = 1.0
            else:
                if a_bit == a_min:
                    prob_a = PA
                else:
                    prob_a = (1.0 - PA) / (m - 1)
            
            for b_bit in b_indices:
                b_val = B_sorted[b_bit]
                if n == 1:
                    prob_b = 1.0
                else:
                    if b_bit == b_min:
                        prob_b = PB
                    else:
                        prob_b = (1.0 - PB) / (n - 1)
                
                prob = prob_a * prob_b
                new_amask = amask & ~(1 << a_bit)
                new_bmask = bmask & ~(1 << b_bit)
                contrib = (a_val + b_val) if (a_val > b_val) else 0
                total += prob * (contrib + dp(new_amask, new_bmask))
        
        return total
    
    result = dp(a_all, b_all)
    print("{0:.15f}".format(result))

if __name__ == '__main__':
    main()
0