結果

問題 No.3047 Verification of Sorting Network
ユーザー 👑 Mizar
提出日時 2025-01-09 15:14:32
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 846 ms / 2,000 ms
コード長 8,570 bytes
コンパイル時間 434 ms
コンパイル使用メモリ 82,508 KB
実行使用メモリ 79,796 KB
最終ジャッジ日時 2025-03-05 20:30:34
合計ジャッジ時間 33,106 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 61
権限があれば一括ダウンロードができます

ソースコード

diff #

"""
yukicoder Problem: Verify Sorting Network
"""
import functools
import math
import sys


SHOW_PROGRESS = False
UNLIMITED = False
MAX_TESTCASES = 1000
MAX_N = 27
MAX_COST = 1e8
STATE_ZERO = 0
STATE_ONE = 1
STATE_UNKNWON = 2


class IsSortingOk:
    """is_sorting_network Ok type"""
    def __init__(self, value: list[bool]):
        self.value = value

    def __bool__(self):
        return True

    def __str__(self):
        return 'Yes'

    def get_data(self):
        """get data value"""
        return self.value


class IsSortingNg:
    """is_sorting_network Ng type"""
    def __init__(self, value: list[bool]):
        self.value = value

    def __bool__(self):
        return False

    def __str__(self):
        return 'No'

    def get_error(self):
        """get error value"""
        return self.value


def is_sorting_network(n: int, net: list[tuple[int, int]], show_progress=False) -> IsSortingOk | IsSortingNg:
    """
    与えられたネットワークが sorting network であるかどうかを調べます。
    時間計算量 O(m * φ**n) で動作します。 φ は黄金比1.618...です。
    """
    assert 2 <= n
    # 0-indexed 入力の範囲を確認
    assert all(0 <= a < b < n for a, b in net)
    # 比較器の数
    m = len(net)
    # 初期状態はすべて '?' = 不定: 0 または 1 に決定されていない
    stack = [(0, [STATE_UNKNWON] * n, 0, n - 1)]
    # 使われる事がある比較器かどうかを記録
    unused = [True] * m
    # ソートされない位置を記録
    unsorted = [False] * (n - 1)

    def fib1(n: int) -> list[int]:
        """フィボナッチ数列 [1,1,2,3,…,Fib(n+1)] を生成します。"""
        return functools.reduce(lambda x, _: x + [sum(x[-2:])], range(n), [1])

    # フィボナッチ数列を生成
    fib = fib1(n)
    # 探索枝の進捗: 0 から fib[n] まで
    progress, next_progress = 0, 0
    # 'a: 分岐スタックのループ
    while stack:
        # 分岐スタックを取得
        # i: 比較器のインデックス
        # p: 状態ベクトル: 各ラインの状態 (ZERO:'0'=0 ≤ UNKNOWN:'?'=不定 ≤ ONE:'1'=1)
        # z: 先頭の非0位置
        # o: 末尾の非1位置
        i, p, z, o = stack.pop()
        # 'b: 比較器のループ
        while i < m:
            # 比較器を取得
            a, b = net[i]
            i += 1
            # すべての分岐で p が '0...01...1' または '0...0?1...1' にソートされるかどうかを確認します。
            if p[a] == STATE_ZERO or p[b] == STATE_ONE:
                pass  # p[a]=='0' もしくは p[b]=='1' の場合、何もしない
            elif p[a] == STATE_UNKNWON and p[b] == STATE_UNKNWON:
                # 入力(p[a],p[b]) == (?,?) の場合、
                # (p[a],p[b]) が (1,0) となる可能性を内包しているため、交換が起きる入力が有り得る
                unused[i - 1] = False
                # 3値論理の組1つでは比較交換器の出力を正確に表現できないため分岐を作ります:
                # 2値論理で表現すると出力は: {(0,0),(0,1),(1,1)}
                # 3値論理で表現すると: {(0,0),(?,1)} または {(0,?),(1,1)}
                # この実装では {(0,0),(?,1)} の分岐を作ります。
                q = p.copy()  # 分岐用に状態をコピー
                # 分岐 (?,?) → (0,0),(?,1)
                q[a], q[b], p[b] = STATE_ZERO, STATE_ZERO, STATE_ONE
                # 'q' の先頭の非0位置を確認
                for j in range(z, o):
                    if q[j] >= 0:
                        # 'q' がまだソートされていない場合
                        stack.append((i, q, j, o))
                        break  # 'b: 次の比較器へ
                else:
                    progress += 1
                # p の末尾の非1位置を確認
                for j in range(o, z, -1):
                    if p[j] < n:
                        # p がまだソートされていない場合
                        o = j
                        break  # 'b: 次の比較器へ
                else:
                    # この分岐で p がソートされている場合:
                    progress += 1
                    break  # 'a: 次の分岐へ
            else:
                # p[a]!='0' かつ p[b]!='1' かつ (p[a],p[b])!=('?','?') の場合
                # (p[a],p[b]) が (1,0) となる可能性を内包しているため、交換が起きる入力が有り得る
                unused[i - 1] = False
                # (p[a],p[b]) が [(?,0),(1,0),(1,?)] の場合、
                # (p[a],p[b]) → (p[b],p[a]) のように交換します。
                p[a], p[b] = p[b], p[a]
                # p の先頭の非0位置を確認
                for j in range(z, o):
                    if p[j] >= 0:
                        # p がまだソートされていない場合
                        z = j
                        break  # 'b: 次の比較器へ
                else:
                    # この分岐で p がソートされている場合:
                    progress += 1
                    break  # 'a: 次の分岐へ
                # p の末尾の非1位置を確認
                for j in range(o, z, -1):
                    if p[j] < n:
                        # p がまだソートされていない場合
                        o = j
                        break  # 'b: 次の比較器へ
                else:
                    # この分岐で p がソートされている場合:
                    progress += 1
                    break  # 'a: 次の分岐へ
        else:
            # すべての比較器を使用しても p がソートされていない分岐がある場合
            # ソートされない位置を記録
            for j in range(n - 1):
                if p[j] != STATE_ZERO and p[j + 1] != STATE_ONE:
                    unsorted[j] = True
            # 残り未知数に応じた進捗を加算
            progress += fib[p.count(STATE_UNKNWON)]
        # 進捗を表示
        if show_progress and progress >= next_progress:
            percent = progress * 100 // fib[n]
            sys.stderr.write(f'{percent}%\r')
            # 1% 進むごとに更新: ceil((percent + 1) * fib[n] / 100)
            next_progress = ((percent + 1) * fib[n] - 1) // 100 + 1
    if show_progress:
        sys.stderr.write('\n')
    # 探索枝の数がフィボナッチ数列の値と一致することを確認
    assert progress == fib[n]
    # ソートされていない分岐がある場合
    if any(unsorted):
        return IsSortingNg(unsorted)
    # すべての分岐でソートされている場合
    return IsSortingOk(unused)


# 黄金比 (1+sqrt(5))/2 ≒ 1.618033988749895
PHI = math.sqrt(1.25) + 0.5  # 黄金比


def main():
    """テストケースの入出力処理"""
    t = int(sys.stdin.readline())
    assert t <= MAX_TESTCASES or UNLIMITED
    cost = 0
    for _ in range(t):
        n, m = map(int, sys.stdin.readline().split())
        assert 2 <= n <= MAX_N or UNLIMITED
        assert 1 <= m <= n * (n - 1) // 2 or UNLIMITED
        cost += m * PHI**n  # テストケースの計算量コスト
        assert cost <= MAX_COST or UNLIMITED
        # 1-indexed -> 0-indexed
        a = map(lambda x: int(x) - 1, sys.stdin.readline().split())
        b = map(lambda x: int(x) - 1, sys.stdin.readline().split())
        cmps: list[tuple[int, int]] = list(zip(a, b))
        assert len(cmps) == m
        assert all(0 <= a < b < n for a, b in cmps)
        # is_sorting: 与えられたネットワークが sorting network であるかどうか
        is_sorting = is_sorting_network(n, cmps, show_progress=SHOW_PROGRESS)
        print(is_sorting)  # Yes or No
        if is_sorting:
            # unused_cmp: 使われない比較器かどうか (is_sorting=True の場合のみ)
            unused_cmp = is_sorting.get_data()
            assert len(unused_cmp) == m
            print(sum(unused_cmp))
            print(*map(lambda e: e[0] + 1, filter(lambda e: e[1], enumerate(unused_cmp))))
        else:
            # unsorted_pos: ソートされない可能性のある位置 (is_sorting=False の場合のみ)
            unsorted_pos = is_sorting.get_error()
            assert len(unsorted_pos) == n - 1
            print(sum(unsorted_pos))
            print(*map(lambda e: e[0] + 1, filter(lambda e: e[1], enumerate(unsorted_pos))))


main()
0