結果

問題 No.421 しろくろチョコレート
ユーザー rpy3cpp
提出日時 2016-09-09 23:34:50
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 648 ms / 2,000 ms
コード長 3,220 bytes
コンパイル時間 191 ms
コンパイル使用メモリ 82,192 KB
実行使用メモリ 83,992 KB
最終ジャッジ日時 2024-09-23 07:07:06
合計ジャッジ時間 9,323 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 65
権限があれば一括ダウンロードができます

ソースコード

diff #

def bp_match(Ss, Ts, Es):
    '''2部グラフのマッチング問題をとく。
    '''
    super_source = len(Ss) + len(Ts)
    super_target = super_source + 1
    Cs = [dict() for i in range(super_target + 1)]
    for si, ti in Es:
        Cs[si][ti] = 1
        Cs[ti][si] = 0
    for si in Ss:
        Cs[super_source][si] = 1
        Cs[si][super_source] = 0
    for ti in Ts:
        Cs[super_target][ti] = 0
        Cs[ti][super_target] = 1
    return dinic(Cs, super_target + 1, super_source, super_target)

def dinic(cf, nV, s, t):
    dist = get_distance(cf, s, t)
    while dist[t] > 0:
        df = dfs(dist, cf, s, t, float('inf'))
        while df:
            df = dfs(dist, cf, s, t, float('inf'))
        dist = get_distance(cf, s, t)
    return sum(cf[t].values())

def get_distance(cf, s, t):
    dist = [-1] * len(cf)
    dist[s] = 0
    frontiers = [s]
    while frontiers:
        new_frontiers = []
        for u in frontiers:
            for v, capacity in cf[u].items():
                if dist[v] == -1 and capacity > 0:
                    dist[v] = dist[u] + 1
                    new_frontiers.append(v)
        frontiers = new_frontiers
    return dist

def dfs(dist, cf, u, t, df):
    if u == t:
        return df
    for v, capacity in cf[u].items():
        if dist[v] > dist[u] and capacity > 0:
            new_df = dfs(dist, cf, v, t, min(df, capacity))
            if new_df > 0:
                cf[u][v] -= new_df
                cf[v][u] += new_df
                return new_df
    return 0


def read_data():
    N, M = map(int, input().split())
    Ws = []
    Bs = []
    Es = []
    for i in range(N):
        Si = input()
        for j, s in enumerate(Si):
            if s == '.':
                continue
            elif s == 'w':
                Ws.append(i * M + j)
            elif s == 'b':
                Bs.append(i * M + j)
    Bset = set(Bs)
    for w in Ws:
        i, j = divmod(w, M)
        if i > 0 and w - M in Bset:
            Es.append((w, w - M))
        if w + M in Bset:
            Es.append((w, w + M))
        if j > 0 and w - 1 in Bset:
            Es.append((w, w - 1))
        if j < M - 1 and w + 1 in Bset:
            Es.append((w, w + 1))
    return N, M, Ws, Bs, Es

def renumber(Ws, Bs, Es):
    W_old2new = dict()
    B_old2new = dict()
    for i, w in enumerate(Ws):
        W_old2new[w] = i
    for i, b in enumerate(Bs, len(Ws)):
        B_old2new[b] = i
    newEs = []
    for w, b in Es:
        newEs.append((W_old2new[w], B_old2new[b]))
    newWs = list(range(len(Ws)))
    newBs = list(range(len(Ws), len(Ws) + len(Bs)))
    return newWs, newBs, newEs

def solve(N, M, Ws, Bs, Es):
    '''
    貪欲に、並んでwbのペアをとれるだけ取る。-> 二部マッチングで求める。
    残りのバラバラのをペアにする。
    w のみ or b のみを食べる。
    '''
    Ws, Bs, Es = renumber(Ws, Bs, Es)
    nw = len(Ws)
    nb = len(Bs)
    connected_pair = bp_match(Ws, Bs, Es)
    unconnected_pair = min(nw, nb) - connected_pair
    single = nw + nb - min(nw, nb) * 2
    return 100 * connected_pair + 10 * unconnected_pair + 1 * single

N, M, Ws, Bs, Es = read_data()
print(solve(N, M, Ws, Bs, Es))
0