結果

問題 No.200 カードファイト!
ユーザー rpy3cpprpy3cpp
提出日時 2015-08-10 07:47:04
言語 Python3
(3.13.1 + numpy 2.2.1 + scipy 1.14.1)
結果
AC  
実行時間 59 ms / 2,000 ms
コード長 7,925 bytes
コンパイル時間 290 ms
コンパイル使用メモリ 13,440 KB
実行使用メモリ 11,776 KB
最終ジャッジ日時 2024-10-15 09:00:28
合計ジャッジ時間 2,458 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 38 ms
11,648 KB
testcase_01 AC 43 ms
11,776 KB
testcase_02 AC 39 ms
11,648 KB
testcase_03 AC 38 ms
11,776 KB
testcase_04 AC 38 ms
11,776 KB
testcase_05 AC 59 ms
11,776 KB
testcase_06 AC 51 ms
11,648 KB
testcase_07 AC 39 ms
11,648 KB
testcase_08 AC 45 ms
11,520 KB
testcase_09 AC 38 ms
11,648 KB
testcase_10 AC 38 ms
11,776 KB
testcase_11 AC 38 ms
11,648 KB
testcase_12 AC 38 ms
11,648 KB
testcase_13 AC 38 ms
11,648 KB
testcase_14 AC 38 ms
11,648 KB
testcase_15 AC 38 ms
11,648 KB
testcase_16 AC 38 ms
11,776 KB
testcase_17 AC 38 ms
11,648 KB
testcase_18 AC 37 ms
11,648 KB
testcase_19 AC 38 ms
11,776 KB
testcase_20 AC 51 ms
11,776 KB
testcase_21 AC 38 ms
11,648 KB
testcase_22 AC 38 ms
11,776 KB
testcase_23 AC 37 ms
11,776 KB
testcase_24 AC 38 ms
11,648 KB
testcase_25 AC 38 ms
11,648 KB
testcase_26 AC 37 ms
11,520 KB
testcase_27 AC 38 ms
11,648 KB
testcase_28 AC 38 ms
11,776 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import fractions
import math

memo = dict()
def read_data():
    N = int(input())
    A = int(input())
    Bs = list(map(int, input().split()))
    C = int(input())
    Ds = list(map(int, input().split()))
    return N, A, Bs, C, Ds

def compress(lstA, lstB):
    '''数値のリスト lstA, lstB を入力として、ai, bi の大小関係に基づく座標圧縮を行う。
    ai > bi が成り立つか否かに着目している。
    ex. lstA = [1,3,7,9,10]  ->[1]   [3]       [7,9,10]
        lstB = [2,3,5,6,11]  ->   [2]   [3,5,6]        [11]
    変換後は、以下のようになる。
    freqA = [1,1,3]
    freqB = [1,3,1]
    '''
    lstA.sort()
    lstB.sort()
    lsts = [lstA, lstB]
    lens = [len(lstA), len(lstB)]
    freqs = [[], []]
    idxs = [0, 0]
    flag = 1
    while True:
        count = 0
        flag = 1 - flag
        while ((flag == 1 and lsts[0][idxs[0]] > lsts[1][idxs[1]]) or 
               (flag == 0 and lsts[0][idxs[0]] <= lsts[1][idxs[1]])):
            idxs[flag] += 1
            count += 1
            if idxs[flag] == lens[flag]:
                freqs[flag].append(count)
                freqs[1-flag].append(lens[1-flag] - idxs[1-flag])
                if len(freqs[1]) < len(freqs[0]):
                    freqs[1].append(0)
                return tuple(freqs)
        freqs[flag].append(count)


def take_firstN(lst, n):
    '''数値を要素とするリスト lst の先頭から、合計が n になるまで、要素をとってくる。
       lst の合計が n に満たない場合は、例外を出す。
    ex. take_firstN([2,3,0,4,1], 7) = [2, 3, 0, 2, 0]
        take_firstN([2,3,0,4,1], 100) = [2, 3, 0, 4, 1]
    '''
    ret = [0] * len(lst)
    cum = 0
    for i, a in enumerate(lst):
        if cum + a >= n:
            ret[i] = n - cum
            return tuple(ret)
        ret[i] = a
        cum += a
    raise RuntimeError('sum(lst)={} must not be smaller than n={}.'.format(sum(lst), n))

def take_lastN(lst, n):
    '''take_firstN() の逆。lstの後ろから、合計がnになるまで要素をとってくる。
    '''
    ret = [0] * len(lst)
    cum = 0
    i = len(lst)
    for a in lst[::-1]:
        i -= 1
        if cum + a >= n:
            ret[i] = n - cum
            return tuple(ret)
        ret[i] = a
        cum += a
    raise RuntimeError('sum(lst)={} must not be smaller than n={}.'.format(sum(lst), n))

    
def solve(N, A, Bs, C, Ds):
    '''
    Bs, Ds の最小公倍数で1ブロック分とする。これの繰り返しパターンの計算を避けるため、
    まずは、これを切り出す。
    残りの部分について、Bsの繰り返し回数 + Bsの余り、とDsの繰り返し回数_Dsの余り を処理する。
    Bsの余り部分は、Bsから大きいのを持ってくればよい。
    Dsの余り部分は、Dsから小さいのを持ってくればよい。
    Bs 1つ目について、Ds 何個か + Ds 先頭いくつか
    とする。先頭いくつかを選ぶパターンを全て総当たりする。
    Bs 2つ目について、Ds 残りいくつか+Ds何個か+Ds先頭いくつか
    とする。残りいくつかは、その前のステップで決まる。Ds先頭いくつかを先ほどと同じように総当たりする。
    深さ優先探索での、奥の方の結果は、メモ化して、同じ探索を繰り返さないようにする。
    (さらには、探索必要なく、貪欲法で良い模様。。)
    '''
    Bs.sort()
    Ds.sort()
    minB = Bs[0]
    maxD = Ds[-1]
    if minB > maxD: return N
    maxB = Bs[-1]
    minD = Ds[0]
    if maxB <= minD: return 0
    lcm = A * C // math.gcd(A, C)
    blocks, residue = divmod(N, lcm)
    freqBs, freqDs = compress(Bs, Ds)
    result = 0
    if blocks:
        result += blocks * dfs(lcm, A, freqBs, C, freqDs)
    if residue:
        result += dfs(residue, A, freqBs, C, freqDs)
    return result


def dfs(N, nB, Bs, nD, Ds):
    '''A, C の大きさによって場合分け。本当は、1つにまとめたい。。
    '''
    global memo
    memo = dict()
    if nB >= nD:
        return dfsAC(N, nB, Bs, nD, Ds, [0] * len(Bs))
    else:
        return dfsCA(N, nB, Bs, nD, Ds, [0] * len(Bs))
        

def dfsCA(N, nB, Bs, nD, Ds, B_head):
    ''' nB < nD , N <= lcm(nB, nD) は保証されているとする。
    '''
    global memo   # メモ化
    if N < nD:
        nD = N
        Ds = take_firstN(Ds, N)
    len_head = sum(B_head)
    if N < len_head:
        len_head = N
        B_head = take_lastN(B_head, N)
    if (N, tuple(B_head)) in memo:
        return memo[N, tuple(B_head)]

    # nD = len_head + n_body * nB + len_tail と分解する。
    n_body, len_tail = divmod(nD - len_head, nB)
    Bs_head_body = [h + b * n_body for h, b in zip(B_head, Bs)]

    # 最後の1区画のときは、B_tail として、Bsの大きいのを与えて計算すればよい。
    if N == nD:
        B_tail = take_lastN(Bs, len_tail)
        Bss = [hb + t for hb, t in zip(Bs_head_body, B_tail)]
        return count_win(Bss, Ds)

    # B_tail を総当たりで試す。
    record = 0
    for B_tail in generate_tails(Bs, nB, len_tail, len(Bs)):
        B_new_head = [b - t for b, t in zip(Bs, B_tail)]
        Bss = [hb + t for hb, t in zip(Bs_head_body, B_tail)]
        score = count_win(Bss, Ds)
        score += dfsCA(N - nD, nB, Bs, nD, Ds, B_new_head)
        if record < score: record = score
        if record == N: break   # 全勝なら探索打ち切り
    memo[N, tuple(B_head)] = record
    return record


def dfsAC(N, nB, Bs, nD, Ds, D_head):
    ''' nB >= nD , N <= lcm(nB, nD) は保証されているとする。
    '''
    global memo   # メモ化
    if N < nB:
        nB = N
        Bs = take_lastN(Bs, N)
    len_head = sum(D_head)
    if N < len_head:
        len_head = N
        D_head = take_firstN(D_head, N)
    if (N, tuple(D_head)) in memo:
        return memo[N, tuple(D_head)]

    # nB = len_head + n_body * nD + len_tail と分解する。
    n_body, len_tail = divmod(nB - len_head, nD)
    Ds_head_body = [h + d * n_body for h, d in zip(D_head, Ds)]

    # 最後の1区画のとき
    if N == nB:
        D_tail = take_firstN(Ds, len_tail)
        Dss = [hb + t for hb, t in zip(Ds_head_body, D_tail)]
        return count_win(Bs, Dss)

    # 境界部分を総当たりで試す。
    record = 0
    for D_tail in generate_tails(Ds, nD, len_tail, len(Ds)):
        D_new_head = [d - t for d, t in zip(Ds, D_tail)]
        Dss = [hb + t for hb, t in zip(Ds_head_body, D_tail)]
        score = count_win(Bs, Dss)
        score += dfsAC(N - nB, nB, Bs, nD, Ds, D_new_head)
        if record < score: record = score
        if record == N: break   # 全勝なら探索打ち切り
    memo[N, tuple(D_head)] = record
    return record


def generate_tails(freq, n, k, idx):
    '''頻度を表す数値のリスト freq から、合計 k 個の要素を抽出したリスト を列挙する。
    '''
    if n == k:
        yield freq[:idx]
        return
    idx -= 1
    f = freq[idx]
    begin = max(f + k - n, 0)
    end = 1 + min(f, k)
    for i in range(begin, end):
        for lst in generate_tails(freq, n - f, k - i, idx):
            yield lst + [i]


def count_win(f0, f1):
    '''ai > bj となる組み合わせを、最大何個作れるかを返す。
    f0, f1 は、インデックス位置の数が何個あるかを表した頻度のリスト。
    f0 = [1,2,3] => 元の数字のリスト [0,1,1,2,2,2]    
    f1 = [2,1,3] => 元の数字のリスト [0,0,1,2,2,2]
    '''
    count = 0
    stock = 0
    for a, b in zip(f0[1:], f1[:-1]):
        stock += b
        if a > stock:
            count += stock
            stock = 0
        else:
            count += a
            stock -= a
    return count


if __name__ == '__main__':
    N, A, Bs, C, Ds = read_data()
    print(solve(N, A, Bs, C, Ds))
0