結果

問題 No.3047 Verification of Sorting Network
ユーザー 👑 Mizar
提出日時 2025-02-10 14:45:04
言語 Python3
(3.13.1 + numpy 2.2.1 + scipy 1.14.1)
結果
AC  
実行時間 1,422 ms / 2,000 ms
コード長 4,015 bytes
コンパイル時間 382 ms
コンパイル使用メモリ 12,416 KB
実行使用メモリ 11,136 KB
最終ジャッジ日時 2025-03-05 20:39:40
合計ジャッジ時間 58,730 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 61
権限があれば一括ダウンロードができます

ソースコード

diff #

"""
yukicoder Problem: Verify Sorting Network
O(m*2^n)解法 + 32768bit並列化
"""
import math
import sys


MAX_TESTCASES = 10000
MAX_N = 27
MAX_COST = 1e8
STATE_ZERO = 0
STATE_ONE = 1
STATE_UNKNWON = 2


class IsSortingOk:
    """is_sorting_network Ok type"""
    def __init__(self, value: list[bool]):
        self.value = value

    def __bool__(self):
        return True

    def __str__(self):
        return 'Yes'

    def get_data(self):
        """get data value"""
        return self.value


class IsSortingNg:
    """is_sorting_network Ng type"""
    def __init__(self, value: list[bool]):
        self.value = value

    def __bool__(self):
        return False

    def __str__(self):
        return 'No'

    def get_error(self):
        """get error value"""
        return self.value


def is_sorting_network(n: int, net: list[tuple[int, int]]) -> IsSortingOk | IsSortingNg:

    """
    与えられたネットワークが sorting network であるかどうかを調べます。
    Args:
        n: 入力の数
        net: 比較器のリスト
    """
    assert 2 <= n
    # 0-indexed 入力の範囲を確認
    assert all(0 <= a < b < n for a, b in net)
    # 比較器の数
    m = len(net)

    unused = [True] * m
    unsorted = [False] * (n - 1)

    # (2**pbits)-bit の状態ベクトルを並列に処理する
    pbits = 15
    # (2**pbits)-bit の 1 が立っているビット列
    pfull = (1 << (1 << pbits)) - 1
    lows = []
    for i in range(pbits):
        le = ((1 << (1 << i)) - 1) << (1 << i)
        for j in range(i + 1, pbits):
            le |= (le << (1 << j))
        lows.append(le)

    # 2**n 通りの {0,1} 入力を 2**pbits 通りずつ処理する
    for i in range(1 << max(n - pbits, 0)):
        p = lows + [(pfull if ((i >> j) & 1) == 1 else 0) for j in range(n - pbits)]
        for j, (a, b) in enumerate(net):
            na = p[a] & p[b]
            if p[a] != na:
                p[a], p[b] = na, p[a] | p[b]
                unused[j] = False
        for j in range(n - 1):
            if (p[j] & ~p[j + 1]) != 0:
                unsorted[j] = True

    # ソートされていない入力パターンがある場合
    if any(unsorted):
        return IsSortingNg(unsorted)
    # すべての入力パターンでソートされている場合
    return IsSortingOk(unused)


# 黄金比 (1+sqrt(5))/2 ≒ 1.618033988749895
PHI = math.sqrt(1.25) + 0.5  # 黄金比


def main():
    """テストケースの入出力処理"""
    t = int(sys.stdin.readline())
    assert t <= MAX_TESTCASES
    cost = 0
    for _ in range(t):
        n, m = map(int, sys.stdin.readline().split())
        assert 2 <= n <= MAX_N
        assert 1 <= m <= n * (n - 1) // 2
        cost += m * PHI**n  # テストケースの計算量コスト
        assert cost <= MAX_COST
        # 1-indexed -> 0-indexed
        a = map(lambda x: int(x) - 1, sys.stdin.readline().split())
        b = map(lambda x: int(x) - 1, sys.stdin.readline().split())
        cmps: list[tuple[int, int]] = list(zip(a, b))
        assert len(cmps) == m
        assert all(0 <= a < b < n for a, b in cmps)
        # is_sorting: 与えられたネットワークが sorting network であるかどうか
        is_sorting = is_sorting_network(n, cmps)
        print(is_sorting)  # Yes or No
        if is_sorting:
            # unused_cmp: 使われない比較器かどうか (is_sorting=True の場合のみ)
            unused_cmp = is_sorting.get_data()
            assert len(unused_cmp) == m
            print(sum(unused_cmp))
            print(*map(lambda e: e[0] + 1, filter(lambda e: e[1], enumerate(unused_cmp))))
        else:
            # unsorted_pos: ソートされない可能性のある位置 (is_sorting=False の場合のみ)
            unsorted_pos = is_sorting.get_error()
            assert len(unsorted_pos) == n - 1
            print(sum(unsorted_pos))
            print(*map(lambda e: e[0] + 1, filter(lambda e: e[1], enumerate(unsorted_pos))))


main()
0