結果

問題 No.3047 Verification of Sorting Network
ユーザー 👑 Mizar
提出日時 2025-03-08 20:00:54
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 860 ms / 2,000 ms
コード長 5,792 bytes
コンパイル時間 403 ms
コンパイル使用メモリ 81,980 KB
実行使用メモリ 140,388 KB
最終ジャッジ日時 2025-03-08 20:01:26
合計ジャッジ時間 26,928 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 61
権限があれば一括ダウンロードができます

ソースコード

diff #

"""
yukicoder Problem: Verify Sorting Network
"""
import sys


# Write a percentage progress indicator to stderr while checking large networks
SHOW_PROGRESS: bool = True
# Smallest wire count n for which the progress indicator is shown
PROGRESS_THRESHOLD: int = 28
# Smallest wire count n for which search states are deduplicated in a set
# (despite the name, a set is used) instead of kept in a plain list
DICT_THRESHOLD: int = 20


class IsSortingOk:
    """Successful result of ``is_sorting_network``: the network sorts every input.

    Instances are truthy and print as ``Yes``. The payload holds one flag
    per comparator telling whether that comparator is unused.
    """
    def __init__(self, value: list[bool]):
        # value[k] is True when comparator k is unused
        self.value = value

    def __bool__(self):
        return True

    def __str__(self):
        return 'Yes'

    def __repr__(self):
        return f'{type(self).__name__}({self.value!r})'

    def get_data(self):
        """Return the per-comparator "unused" flags."""
        return self.value


class IsSortingNg:
    """Failed result of ``is_sorting_network``: some input is left unsorted.

    Instances are falsy and print as ``No``. The payload holds one flag per
    adjacent wire pair (i, i + 1) telling whether that pair may end up unsorted.
    """
    def __init__(self, value: list[bool]):
        # value[i] is True when wires i and i + 1 may end up out of order
        self.value = value

    def __bool__(self):
        return False

    def __str__(self):
        return 'No'

    def __repr__(self):
        return f'{type(self).__name__}({self.value!r})'

    def get_error(self):
        """Return the per-position "may be unsorted" flags."""
        return self.value


def is_sorting_network(n: int, net: list[tuple[int, int]]) -> IsSortingOk | IsSortingNg:
    """
    Check whether the comparator network ``net`` sorts every input on ``n`` wires.

    Each comparator ``(a, b)`` with ``a < b`` leaves the smaller value on wire
    ``a`` and the larger on wire ``b``. Instead of enumerating all 0/1
    inputs, the search tracks symbolic states in which every wire is 0, 1 or
    '?' (either value, independently of the other wires). A state is a pair
    of bitmasks ``(z, o)``:
      - bit i of ``z`` is set when wire i may be 0;
      - bit i of ``o`` is set when wire i may be 1.
    States that can no longer be unsorted are discarded.
    Operates in time complexity O(m * phi**n). phi is the golden ratio 1.618...

    Args:
        n: number of wires (at least 2).
        net: comparators as 0-indexed pairs ``(a, b)`` with ``0 <= a < b < n``.

    Returns:
        ``IsSortingOk`` when no 0/1 input stays unsorted. It holds one flag
        per comparator telling whether that comparator never has to swap.
        Otherwise ``IsSortingNg``. It holds one flag per adjacent pair
        ``(i, i + 1)``, ``0 <= i < n - 1``, telling whether some input can
        end with wire i = 1 and wire i + 1 = 0.

    Raises:
        AssertionError: if ``n < 2`` or a comparator is out of range.
    """
    assert 2 <= n
    # Check the range of 0-indexed inputs
    assert all(0 <= a < b < n for a, b in net)
    # Number of comparators
    m = len(net)
    # unused_cmp[k]: comparator k never swapped in any live state
    unused_cmp: list[bool] = []
    show_progress = SHOW_PROGRESS and n >= PROGRESS_THRESHOLD
    # For large n, deduplicate identical branches with a set;
    # for small n a plain list avoids the hashing cost.
    use_set = DICT_THRESHOLD <= n
    # Initial state is all '?' = indeterminate: not determined to be 0 or 1
    all_wires = (1 << n) - 1
    queue: set[tuple[int, int]] | list[tuple[int, int]]
    queue = {(all_wires, all_wires)} if use_set else [(all_wires, all_wires)]
    for i, (a, b) in enumerate(net):
        queue_next: set[tuple[int, int]] | list[tuple[int, int]] = set() if use_set else []
        push = queue_next.add if use_set else queue_next.append
        bit_a, bit_b = 1 << a, 1 << b
        unused_f = True
        for z, o in queue:
            if not (o & bit_a) or not (z & bit_b):
                # Wire a is surely 0 or wire b is surely 1: the comparator is a no-op
                push((z, o))
            elif (z & bit_a) and (o & bit_b):
                # Both wires are '?': the outputs are (0, 0) or (?, 1)
                unused_f = False
                zz, oo = z, o ^ bit_a ^ bit_b
                # Keep a branch only while some adjacent pair can still be descending
                if oo & (zz >> 1):
                    push((zz, oo))
                z ^= bit_b
                if o & (z >> 1):
                    push((z, o))
            else:
                # Wire a may be 1 and wire b may be 0, at least one of them fixed:
                # the comparator swaps, so exchange bits a and b in both masks
                unused_f = False
                xz = ((z >> a) ^ (z >> b)) & 1
                xo = ((o >> a) ^ (o >> b)) & 1
                z ^= (xz << a) | (xz << b)
                o ^= (xo << a) | (xo << b)
                if o & (z >> 1):
                    push((z, o))
        unused_cmp.append(unused_f)
        queue = queue_next
        if show_progress:
            # (i + 1) so the indicator reaches 100% after the last comparator
            sys.stderr.write(f'{(i + 1) * 100 // m}%\r')
    if show_progress:
        sys.stderr.write('\n')
    # Bit i set: some surviving branch may end with wire i = 1 and wire i + 1 = 0
    unsorted_i = 0
    for z, o in queue:
        unsorted_i |= o & (z >> 1)
    if unsorted_i:
        unsorted_pos = [((unsorted_i >> i) & 1) != 0 for i in range(n - 1)]
        return IsSortingNg(unsorted_pos)
    # If all branches are sorted
    return IsSortingOk(unused_cmp)


def main():
    """Solve every test case read from standard input.

    Input: the test-case count, then per case a line ``n m`` followed by two
    lines holding the first and second 1-indexed wire of each of the ``m``
    comparators.

    Output per case: ``Yes`` or ``No``, then a count and the 1-indexed
    positions:
      - for ``Yes``, the unused comparators;
      - for ``No``, the adjacent wire pairs that may end up unsorted.
    """
    def report(flags):
        # Print how many flags are set, then the 1-indexed positions of the set ones
        print(sum(flags))
        print(*[pos for pos, flag in enumerate(flags, start=1) if flag])

    read = sys.stdin.readline
    for _ in range(int(read())):
        n, m = map(int, read().split())
        # Comparator endpoints arrive 1-indexed; shift them to 0-indexed wires
        firsts = (int(tok) - 1 for tok in read().split())
        seconds = (int(tok) - 1 for tok in read().split())
        cmps: list[tuple[int, int]] = list(zip(firsts, seconds))
        assert len(cmps) == m
        assert all(0 <= lo < hi < n for lo, hi in cmps)
        # Decide whether the network sorts every input
        is_sorting = is_sorting_network(n, cmps)
        print(is_sorting)  # Yes or No
        if not is_sorting:
            # Adjacent positions that may remain unsorted (failure case only)
            unsorted_pos = is_sorting.get_error()
            assert len(unsorted_pos) == n - 1
            report(unsorted_pos)
        else:
            # Per-comparator "never used" flags (success case only)
            unused_cmp = is_sorting.get_data()
            assert len(unused_cmp) == m
            report(unused_cmp)


# Run only when executed as a script, so importing this module has no side effects
if __name__ == "__main__":
    main()