結果

問題 No.174 カードゲーム(Hard)
ユーザー lam6er
提出日時 2025-04-16 16:31:32
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,502 bytes
コンパイル時間 187 ms
コンパイル使用メモリ 81,924 KB
実行使用メモリ 55,596 KB
最終ジャッジ日時 2025-04-16 16:33:14
合計ジャッジ時間 3,810 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 2 TLE * 1 -- * 9
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from functools import lru_cache

def main():
    sys.setrecursionlimit(1 << 25)
    N, PA, PB = sys.stdin.readline().split()
    N = int(N)
    PA = float(PA)
    PB = float(PB)
    A = list(map(int, sys.stdin.readline().split()))
    B = list(map(int, sys.stdin.readline().split()))
    
    A_sorted = sorted(A)
    B_sorted = sorted(B)
    
    a_index = {a: i for i, a in enumerate(A_sorted)}
    b_index = {b: i for i, b in enumerate(B_sorted)}
    
    initial_a_mask = (1 << N) - 1
    initial_b_mask = (1 << N) - 1
    
    @lru_cache(maxsize=None)
    def get_min_and_count(mask, is_a):
        if mask == 0:
            return (None, 0)
        sorted_list = A_sorted if is_a else B_sorted
        count = bin(mask).count('1')
        for i in range(N):
            if mask & (1 << i):
                return (sorted_list[i], count)
        return (None, 0)
    
    @lru_cache(maxsize=None)
    def dp(a_mask, b_mask):
        a_min, a_count = get_min_and_count(a_mask, True)
        b_min, b_count = get_min_and_count(b_mask, False)
        if a_count == 0 or b_count == 0:
            return 0.0
        
        total = 0.0
        
        a_probs = {}
        if a_count == 1:
            a_probs[a_min] = 1.0
        else:
            a_probs[a_min] = PA
            other_prob = (1.0 - PA) / (a_count - 1)
            for i in range(N):
                a = A_sorted[i]
                if (a_mask & (1 << i)) and a != a_min:
                    a_probs[a] = a_probs.get(a, 0.0) + other_prob
        
        b_probs = {}
        if b_count == 1:
            b_probs[b_min] = 1.0
        else:
            b_probs[b_min] = PB
            other_prob = (1.0 - PB) / (b_count - 1)
            for i in range(N):
                b = B_sorted[i]
                if (b_mask & (1 << i)) and b != b_min:
                    b_probs[b] = b_probs.get(b, 0.0) + other_prob
        
        for a in a_probs:
            a_prob = a_probs[a]
            a_bit = 1 << a_index[a]
            new_a_mask = a_mask ^ a_bit
            for b in b_probs:
                b_prob = b_probs[b]
                b_bit = 1 << b_index[b]
                new_b_mask = b_mask ^ b_bit
                prob = a_prob * b_prob
                current = (a + b) if (a > b) else 0
                total += prob * (current + dp(new_a_mask, new_b_mask))
        
        return total
    
    result = dp(initial_a_mask, initial_b_mask)
    print("{0:.15f}".format(result))

if __name__ == "__main__":
    main()
0