結果

問題 No.460 裏表ちわーわ
ユーザー rpy3cpprpy3cpp
提出日時 2017-04-25 03:31:52
言語 Python3
(3.8.3 + numpy 1.14.5 + scipy 1.1.0)
結果
AC  
実行時間 24 ms / 2,000 ms
コード長 3,267 Byte
コンパイル時間 52 ms
使用メモリ 8,948 KB
最終ジャッジ日時 2020-06-30 13:07:14

テストケース

テストケース表示
入力 結果 実行時間
使用メモリ
testcase_00 AC 19 ms
6,900 KB
testcase_01 AC 18 ms
6,896 KB
testcase_02 AC 19 ms
6,896 KB
testcase_03 AC 18 ms
8,948 KB
testcase_04 AC 18 ms
8,944 KB
testcase_05 AC 20 ms
8,900 KB
testcase_06 AC 20 ms
6,900 KB
testcase_07 AC 23 ms
8,944 KB
testcase_08 AC 24 ms
6,900 KB
testcase_09 AC 18 ms
8,940 KB
testcase_10 AC 24 ms
6,896 KB
testcase_11 AC 19 ms
6,896 KB
testcase_12 AC 21 ms
8,896 KB
testcase_13 AC 23 ms
8,944 KB
testcase_14 AC 24 ms
6,896 KB
testcase_15 AC 24 ms
8,940 KB
testcase_16 AC 19 ms
6,896 KB
testcase_17 AC 24 ms
8,944 KB
testcase_18 AC 23 ms
6,904 KB
testcase_19 AC 24 ms
6,896 KB
testcase_20 AC 19 ms
6,904 KB
testcase_21 AC 24 ms
8,944 KB
testcase_22 AC 24 ms
8,940 KB
testcase_23 AC 24 ms
6,900 KB
testcase_24 AC 23 ms
6,896 KB
testcase_25 AC 17 ms
8,940 KB
testcase_26 AC 17 ms
6,896 KB
testcase_27 AC 19 ms
8,944 KB
権限があれば一括ダウンロードができます

ソースコード

diff #
def read_data():
    M, N = map(int, input().split())
    data = []
    for m in range(M):
        row = list(map(int, input().split()))
        data.append(row)
    return M, N, data

def solve(M, N, data):
    if min(M, N) < 2:
        return brute_force(M, N, data)
    if M < N:
        M, N = N, M
        data = list(map(list, zip(*data)))
    data.append([0] * N)
    bits = lst2num(data[0] + [data[1][0]])
    dp = {bits:0}
    for m in range(M):
        dp = fill_line(dp, m, M, N, data)
    if 0 in dp:
        return dp[0]
    else:
        return "Impossible"

def lst2num(lst):
    num = 0
    for b in lst:
        num *= 2
        num += b
    return num

def fill_line(dp, m, M, N, data):
    is_first = (m == 0)
    is_last = (m == M - 1)
    dp = fill_head(dp, is_first, is_last, m, N, data)
    for n in range(1, N - 1):
        dp = fill_body(dp, is_first, is_last, m, n, N, data)
    dp = fill_tail(dp, is_first, is_last, m, N, data)
    return dp

def set_mask(n, N, is_first, is_last):
    if is_first: return (1 << (N + 2 + n)) - 1
    if is_last: return (1 << (2 * N + 2)) - (1 << (n + 2))
    return (1 << (2 * N + 2)) - 1

def fill_head(dp, is_first, is_last, m, N, data):
    '''
    dp[bits] 2行+2セルがbitsとなる最短手順数
    is_first 先頭行であるか否か
    is_last  最終行であるか否か
    m 何行目か
    N 1行の長さ
    data[r][c] 盤面の情報。表ならば0、裏ならば1
    mask: 1 のところのみが有効。0 のところは無視される。    
    '''
    mask = set_mask(0, N, is_first, is_last)
    rev = 3 + (3 << N) + (3 << (2*N))
    bit = data[m + 1][1]
    dp = update_dp(dp, mask, rev, bit, 0)
    return dp


def fill_body(dp, is_first, is_last, m, n, N, data):
    mask = set_mask(n, N, is_first, is_last)
    rev = 7 + (7 << N) + (3 << (2*N))
    bit = data[m + 1][n + 1]
    UL = 0 if is_first else (1 << (2 * N + 1))
    dp = update_dp(dp, mask, rev, bit, UL)
    return dp


def fill_tail(dp, is_first, is_last, m, N, data):
    mask = set_mask(N - 1, N, is_first, is_last)
    rev = 6 + (6 << N) + (2 << (2*N))
    bit = 0 if is_last else data[m + 2][0]
    ULU = 0 if is_first else (3 << (2 * N))
    dp = update_dp(dp, mask, rev, bit, ULU)
    return dp


def update_dp(dp, mask, rev, bit, ULUs):
    new_dp = {}
    for bits, steps in dp.items():
        if bits & ULUs == ULUs:
            new_bits = ((bits * 2 + bit) ^ rev) & mask
            if (new_bits not in new_dp) or (new_dp[new_bits] > steps + 1):
                new_dp[new_bits] = steps + 1
        if bits & ULUs == 0:
            new_bits = (bits * 2 + bit) & mask
            if (new_bits not in new_dp) or (new_dp[new_bits] > steps):
                new_dp[new_bits] = steps
    return new_dp
    

def brute_force(M, N, data):
    if M != 1:
        M, N = N, M
        data = list(map(list, zip(*data)))
    data = data[0]
    bits = lst2num(data)
    record = N + 1
    mask = (1 << N) - 1
    for pat in range(1 << N):
        rev = ((pat << 1) ^ pat ^ (pat >> 1)) & mask
        if (bits ^ rev) == 0:
            record = min(record, bin(pat).count('1'))
    if record == N + 1:
        return "Impossible"
    else:
        return record

M, N, data = read_data()
print(solve(M, N, data))
0