結果

問題 No.3047 Verification of Sorting Network
ユーザー 👑 Mizar
提出日時 2025-03-19 00:46:28
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,164 ms / 2,000 ms
コード長 10,638 bytes
コンパイル時間 356 ms
コンパイル使用メモリ 82,380 KB
実行使用メモリ 99,184 KB
最終ジャッジ日時 2025-03-19 00:46:55
合計ジャッジ時間 26,530 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 61
権限があれば一括ダウンロードができます

ソースコード

diff #

"""
yukicoder Problem: Verify Sorting Network
"""
import functools
import math
import sys
import time


# Progress lines are written to stderr only when n is at least this large.
PROGRESS_THRESHOLD = 28
# NOTE(review): not referenced anywhere in the code shown here.
GC_THRESHOLD = 1_000_000
# When True, every input-limit assertion in main() passes unconditionally.
UNLIMITED = True
# Input limits checked by main() (only enforced when UNLIMITED is False).
MAX_TESTCASES = 1_000
MAX_N = 27
# Budget for the running sum of m * PHI**n over all test cases.
MAX_COST = 1.0e8


class Dsu:
    """Union-find over ``0..n-1`` with union by size and path compression.

    ``parent[v]`` is the parent index of ``v``. For a root, it stores the
    negated size of that root's set instead.
    """
    def __init__(self, n: int):
        self.n = n
        self.parent = [-1 for _ in range(n)]

    def root(self, x: int) -> int:
        """Return the representative of ``x``'s set.

        Every node on the path from ``x`` is re-pointed directly at the root.
        """
        top = x
        while self.parent[top] >= 0:
            top = self.parent[top]
        node = x
        while node != top:
            nxt = self.parent[node]
            self.parent[node] = top
            node = nxt
        return top

    def size(self, x: int) -> int:
        """Return the number of elements in ``x``'s set."""
        representative = self.root(x)
        return -self.parent[representative]

    def equiv(self, x: int, y: int) -> bool:
        """Return True when ``x`` and ``y`` belong to the same set."""
        rx = self.root(x)
        ry = self.root(y)
        return rx == ry

    def unite(self, x: int, y: int) -> bool:
        """Merge the sets of ``x`` and ``y``.

        Returns False if they were already joined. The larger set's root
        survives; on a tie, ``x``'s root survives.
        """
        rx = self.root(x)
        ry = self.root(y)
        if rx == ry:
            return False
        # parent[] holds negated sizes, so "rx smaller" means parent[rx] > parent[ry].
        if self.parent[rx] > self.parent[ry]:
            rx, ry = ry, rx
        self.parent[rx] += self.parent[ry]
        self.parent[ry] = rx
        return True


class IsSortingOk:
    """Positive verdict returned by ``is_sorting_network``.

    It is truthy, prints as ``Yes``, and wraps one flag per comparator. A flag
    is True when that comparator never had to act on any explored state.
    """
    _LABEL = 'Yes'

    def __init__(self, value: list[bool]):
        self.value = value

    def __str__(self):
        return self._LABEL

    def __bool__(self):
        return True

    def get_data(self):
        """Return the wrapped per-comparator "unused" flags."""
        return self.value


class IsSortingNg:
    """Negative verdict returned by ``is_sorting_network``.

    It is falsy, prints as ``No``, and wraps one flag per adjacent position
    pair ``(i, i + 1)``. A flag is True when that pair may not be sorted.
    """
    _LABEL = 'No'

    def __init__(self, value: list[bool]):
        self.value = value

    def __str__(self):
        return self._LABEL

    def __bool__(self):
        return False

    def get_error(self):
        """Return the wrapped per-position "may be unsorted" flags."""
        return self.value


class Cmp:
    """Job: a batch of comparators whose two wires share one DSU component.

    It stores the component's root and a list of
    ``(comparator_index, a, b)`` triples.
    """
    def __init__(self, root: int, cmp_part: list[tuple[int, int, int]]):
        data = root, cmp_part
        self._data = data

    def data(self) -> tuple[int, list[tuple[int, int, int]]]:
        """Return ``(root, cmp_part)``."""
        return self._data


class Combine:
    """Job: merge two DSU components.

    ``master`` is the root that survives the merge; ``slave`` is the root that
    is absorbed into it.
    """
    def __init__(self, master: int, slave: int):
        pair = (master, slave)
        self._pair = pair

    def master_slave(self) -> tuple[int, int]:
        """Return ``(master, slave)``."""
        return self._pair


def fib1(n: int) -> list[int]:
    """Return the Fibonacci prefix ``[Fib(1), Fib(2), ..., Fib(n+1)]``.

    That is ``[1, 1, 2, 3, 5, ...]`` with ``n + 1`` entries; ``n <= 0`` yields
    ``[1]``. The list is grown in place, so this runs in O(n) list operations
    instead of rebuilding the list on every step.
    """
    seq = [1]
    for _ in range(n):
        # sum of the last two entries (just [1] on the first step)
        seq.append(sum(seq[-2:]))
    return seq


def verify_strategy(n: int, net: list[tuple[int, int]]) -> list[Cmp | Combine]:
    """Generate the processing order (job list) used by ``is_sorting_network``.

    The n wires start as singleton DSU components. Each iteration scans the
    not-yet-scheduled comparators in network order. A comparator is eligible
    only if neither of its wires was touched by an earlier unscheduled
    comparator in the same scan, because that earlier comparator has to be
    applied first.

    Eligible comparators whose two wires already share a component are
    scheduled, grouped per component root, and emitted as ``Cmp`` jobs in
    ascending root order. If a scan schedules nothing, the eligible pair of
    components with the smallest combined size is merged instead (ties broken
    by root indices), and a ``Combine(master, slave)`` job is emitted.

    Args:
        n: number of wires.
        net: 0-indexed comparators ``(a, b)`` in network order.

    Returns:
        The ``Cmp`` / ``Combine`` jobs in execution order.
    """
    layered: list[bool] = [False] * len(net)
    layers: list[Cmp | Combine] = []
    # Length of the prefix of `net` that is fully scheduled; scans start after it.
    skip_len: int = 0
    net_i: list[tuple[int, tuple[int, int]]] = list(enumerate(net))
    dsu = Dsu(n)
    while skip_len < len(net):
        # Wires touched so far in this scan by any unscheduled comparator.
        net_checked = [False] * n
        # Comparators scheduled in this scan, bucketed by component root.
        net_layer = [[] for _ in range(n)]
        # Best merge candidate as (combined size, root_a, root_b). The initial
        # size n + 1 is a sentinel larger than any real combined size.
        net_combine = (n + 1, 0, 0)
        for i, (a, b) in net_i[skip_len:]:
            if layered[i]:
                continue
            checked = net_checked[a] or net_checked[b]
            # Mark both wires even when this comparator is blocked, so that
            # later comparators on these wires are blocked too.
            net_checked[a] = net_checked[b] = True
            if checked:
                continue
            if dsu.equiv(a, b):
                root_a = dsu.root(a)
                net_layer[root_a].append((i, a, b))
                layered[i] = True
            else:
                root_a, root_b = dsu.root(a), dsu.root(b)
                size_a, size_b = dsu.size(a), dsu.size(b)
                net_combine = min(net_combine, (size_a + size_b, root_a, root_b))
        if any(net_layer):
            # Append the comparison part
            for i, ces in enumerate(net_layer):
                if len(ces) == 0:
                    continue
                layers.append(Cmp(i, ces))
            # Advance past the now fully scheduled prefix.
            for f in layered[skip_len:]:
                if f:
                    skip_len += 1
                else:
                    break
        else:
            # Append the combine part
            size, root_a, root_b = net_combine
            # Sentinel still in place: no merge candidate was found.
            if size > n:
                break
            dsu.unite(root_a, root_b)
            root_master = dsu.root(root_a)
            # The slave is whichever of the two old roots is not the master.
            root_slave = root_a ^ root_b ^ root_master
            layers.append(Combine(root_master, root_slave))
    return layers


def is_sorting_network(n: int, net: list[tuple[int, int]]) -> IsSortingOk | IsSortingNg:
    """
    Check if the given network is a sorting network.
    Runs in O(m * phi**n) time complexity. phi is the golden ratio 1.618...

    Method: a three-valued (0 / 1 / '?') search driven by the job list from
    verify_strategy.

    State encoding: each state is a pair of bitmasks (z, o). Bit i of z set
    means position i may be 0; bit i of o set means it may be 1; both set
    means '?'. States are kept per DSU component root.

    Jobs:
      - Combine: replaces the two components' state lists by their Cartesian
        product (bitwise OR of one state from each).
      - Cmp: applies a batch of comparators inside one component, branching
        whenever two '?' positions are compared.

    Args:
        n: number of wires (must be >= 2).
        net: 0-indexed comparators (a, b) with a < b, in network order.

    Returns:
        If every explored state is sorted: IsSortingOk wrapping m flags, where
        True means that comparator never had to act on any explored state.
        Otherwise: IsSortingNg wrapping n - 1 flags, where flag i marks that
        positions i and i + 1 may end up out of order.

    Raises:
        AssertionError: (when assertions are enabled) if n < 2 or a
        comparator does not satisfy 0 <= a < b < n.

    ref: wikipedia: Sorting network
    https://en.wikipedia.org/wiki/Sorting_network
    ref: Hisayasu Kuroda. (1997). A proposal of Gap Decrease Sorting Network.
    Trans.IPS.Japan, vol.38, no.3, p.381-389.
    http://id.nii.ac.jp/1001/00013442/
    ref: brianpursley/sorting-network, PR#9 three-valued-logic DFS approach
    https://github.com/brianpursley/sorting-network/pull/9
    ref: bertdobbelaere/SorterHunter, nw_tool.py
    https://github.com/bertdobbelaere/SorterHunter/blob/master/Util/nw_tool.py
    """
    assert 2 <= n
    # Check the range of 0-indexed input
    assert all(0 <= a < b < n for a, b in net)
    # Number of comparators
    m = len(net)
    # Initial state is all '?' = indeterminate: not determined to be 0 or 1
    # (one single-state list per position; each position is its own component)
    states = [[(1 << i, 1 << i)] for i in range(n)]
    # Record whether the comparator is used
    # (stays True only if the comparator never swapped or branched)
    unused = [True] * m
    # Record the position that is not sorted
    # (bit i set = positions i and i + 1 may be out of order)
    unsorted_i = 0
    # Initialize DSU
    # (mirrors the merges of verify_strategy; here used only for the sizes
    # shown in progress output)
    dsu = Dsu(n)
    # Execute search for each layer
    for job in verify_strategy(n, net):
        if isinstance(job, Combine):
            master, slave = job.master_slave()
            size_master, size_slave = dsu.size(master), dsu.size(slave)
            dsu.unite(master, slave)
            len_master, len_slave = len(states[master]), len(states[slave])
            # Cartesian product of both state lists; the right-hand side is
            # fully evaluated before assignment. The slave list is emptied
            # because the slave is no longer a root.
            states[master], states[slave] = [(sz | mz, so | mo) for sz, so in states[slave] for mz, mo in states[master]], []
            if PROGRESS_THRESHOLD <= n:
                sys.stderr.write(f'Combining, size: {size_master}+{size_slave}=>{size_master + size_slave}, len: {len_master}*{len_slave}=>{len_master * len_slave}, root_master: {master}, root_slave: {slave}\n')
        elif isinstance(job, Cmp):
            root, cmp_part = job.data()
            len_pre = len(states[root])
            # Apply comparators, Generate branch
            # Pass 1: run the whole batch over each existing state, updating it
            # in place. Each deferred alternative is pushed onto `stack` as
            # (comparator index to resume at, z, o).
            stack = []
            for i, (z, o) in enumerate(states[root]):
                for j, (cei, a, b) in enumerate(cmp_part):
                    # a cannot be 1 or b cannot be 0: already in order, no-op
                    if (1 & (o >> a) & (z >> b)) == 0:
                        continue
                    # a is fixed at 1 or b is fixed at 0: the comparator always
                    # exchanges the two values, so swap bits a and b in z and o
                    # (xz / xo are 1 iff the two bits differ; XOR-ing both
                    # positions with that difference swaps them)
                    if (1 & (z >> a) & (o >> b)) == 0:
                        unused[cei] = False
                        xz, xo = (((z >> a) ^ (z >> b)) & 1), (((o >> a) ^ (o >> b)) & 1)
                        z ^= (xz << a) | (xz << b)
                        o ^= (xo << a) | (xo << b)
                    else:
                        # Both a and b are '?'. Deferred branch: both 0 (clear
                        # a and b from o). It resumes at j, which re-checks
                        # this comparator as a no-op.
                        unused[cei] = False
                        stack.append((j, z, o ^ (1 << a) ^ (1 << b)))
                        # Current branch: b is 1 (clear b from z); a stays '?'
                        z ^= 1 << b
                states[root][i] = (z, o)
            # Apply comparators, Branch processing
            # Pass 2: drain deferred branches depth-first. Each resumes at its
            # stored comparator index and may push further branches. Finished
            # states go into a set to drop duplicates.
            states_next: set[tuple[int, int]] = set()
            while stack:
                i, z, o = stack.pop()
                while i < len(cmp_part):
                    cei, a, b = cmp_part[i]
                    i += 1
                    if (1 & (o >> a) & (z >> b)) == 0:
                        continue
                    if (1 & (z >> a) & (o >> b)) == 0:
                        unused[cei] = False
                        xz, xo = (((z >> a) ^ (z >> b)) & 1), (((o >> a) ^ (o >> b)) & 1)
                        z ^= (xz << a) | (xz << b)
                        o ^= (xo << a) | (xo << b)
                    else:
                        unused[cei] = False
                        # `i` was already advanced, so this branch resumes
                        # after the current comparator
                        stack.append((i, z, o ^ (1 << a) ^ (1 << b)))
                        z ^= 1 << b
                states_next.add((z, o))
            del stack
            len_gen = len(states[root]) + len(states_next)
            # Write back the next state
            # (union with the updated pass-1 states, deduplicated)
            states[root] = list(set(states[root]).union(states_next))
            del states_next
            len_dedup = len(states[root])
            if PROGRESS_THRESHOLD <= n:
                sys.stderr.write(f'AppliedCE, size: {dsu.size(root)}, len: {len_pre}=>{len_gen}=>{len_dedup}, root: {root}, cmp: {str(cmp_part)}\n')
    # Check if all branches are sorted
    # Non-root entries are empty lists and contribute nothing.
    for queue in states:
        n1_mask = (1 << (n - 1)) - 1
        # Bits of the positions in this component: every position has its
        # z or o bit set in any state, so the first state suffices.
        q_mask = (queue[0][0] | queue[0][1]) if queue else 0
        # Position i is in this component but i + 1 is not: the two were
        # never compared with each other, so they may be out of order.
        unsorted_i |= (q_mask & (~q_mask >> 1)) & n1_mask
        for z, o in queue:
            # Position i may be 1 while position i + 1 may be 0.
            unsorted_i |= (o & (z >> 1))
    # If there are unsorted branches
    if unsorted_i != 0:
        unsorted = [((unsorted_i >> i) & 1) != 0 for i in range(n - 1)]
        return IsSortingNg(unsorted)
    # If all branches are sorted
    return IsSortingOk(unused)


# Golden ratio phi = (1 + sqrt(5)) / 2 ~= 1.618033988749895. main() uses it as
# the base of the per-test-case cost estimate m * PHI**n.
PHI = (1.0 + math.sqrt(5.0)) / 2.0


def main():
    """Read all test cases from stdin, verify each network, and print results.

    Input: a line ``t``, then per test case a line ``n m``, a line of ``m``
    1-indexed comparator first wires, and a line of ``m`` 1-indexed second
    wires.

    Output per test case: ``Yes`` or ``No``, then a count, then the 1-indexed
    positions. For ``Yes`` these are the unused comparators; for ``No`` they
    are the positions that may not be sorted.

    The elapsed time of each verification is written to stderr.
    """
    t = int(sys.stdin.readline())
    assert t <= MAX_TESTCASES or UNLIMITED
    cost = 0
    for _ in range(t):
        n, m = map(int, sys.stdin.readline().split())
        assert 2 <= n <= MAX_N or UNLIMITED
        assert 1 <= m <= n * (n - 1) // 2 or UNLIMITED
        # Estimated work of this case; the checker is O(m * phi**n).
        cost += m * PHI**n
        assert cost <= MAX_COST or UNLIMITED
        # 1-indexed -> 0-indexed
        a = [int(x) - 1 for x in sys.stdin.readline().split()]
        b = [int(x) - 1 for x in sys.stdin.readline().split()]
        cmps: list[tuple[int, int]] = list(zip(a, b))
        assert len(cmps) == m
        assert all(0 <= lo < hi < n for lo, hi in cmps)
        # perf_counter is monotonic and high-resolution. Unlike time.time(),
        # the measured duration cannot be skewed by wall-clock adjustments.
        start = time.perf_counter()
        is_sorting = is_sorting_network(n, cmps)
        end = time.perf_counter()
        sys.stderr.write(f'{end - start:.6f}[s]\n')
        print(is_sorting)  # Yes or No
        if is_sorting:
            # Whether each comparator is unused (only if is_sorting=True)
            flags = is_sorting.get_data()
            assert len(flags) == m
        else:
            # Whether each position may not be sorted (only if is_sorting=False)
            flags = is_sorting.get_error()
            assert len(flags) == n - 1
        # Both verdicts report the count, then the 1-indexed flagged positions.
        print(sum(flags))
        print(*(i + 1 for i, flag in enumerate(flags) if flag))


# Run only when executed as a script (as the judge does), not on import.
# The stray trailing `0` page residue has been dropped.
if __name__ == '__main__':
    main()