結果

問題 No.3047 Verification of Sorting Network
ユーザー 👑 Mizar
提出日時 2026-04-20 07:12:26
言語 PyPy3 (7.3.17)
コンパイル: pypy3 -mpy_compile _filename_
実行: pypy3 _filename_
結果
AC  
実行時間 408 ms / 2,000 ms
コード長 7,517 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 179 ms
コンパイル使用メモリ 84,992 KB
実行使用メモリ 108,364 KB
最終ジャッジ日時 2026-04-20 07:12:42
合計ジャッジ時間 14,067 ms
ジャッジサーバーID (参考情報): judge2_0 / judge1_0
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 61
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

#!/usr/bin/env pypy3
import sys
from dataclasses import dataclass

# Limits named after the problem's input bounds (presumably T and N; confirm
# against the problem statement). Nothing else in this file reads them.
MAX_TESTCASES, MAX_N = 1_000, 27


class Dsu:
    """Disjoint-set union over the elements ``0 .. n-1``.

    ``parent[v]`` is the parent of ``v``. For a root it is instead
    ``-size``, the negated number of elements in that set. Merges are done
    by size, and lookups shorten the path they walk.
    """

    __slots__ = ("parent",)

    def __init__(self, n: int):
        self.parent = [-1 for _ in range(n)]

    def root(self, x: int) -> int:
        """Return the representative of the set containing ``x``."""
        par = self.parent
        node = x
        while par[node] >= 0:
            up = par[node]
            grand = par[up]
            # Path splitting: hop ``node`` to its grandparent when the parent
            # is not itself a root, then continue from the old parent.
            if grand >= 0:
                par[node] = grand
            node = up
        return node

    def size(self, x: int) -> int:
        """Return the number of elements in ``x``'s set."""
        rep = self.root(x)
        return -self.parent[rep]

    def equiv(self, x: int, y: int) -> bool:
        """Return True when ``x`` and ``y`` are in the same set."""
        rx = self.root(x)
        ry = self.root(y)
        return rx == ry

    def unite(self, x: int, y: int) -> bool:
        """Merge the sets of ``x`` and ``y``; return False if already merged.

        The larger set's root survives; on a tie, ``x``'s root survives.
        """
        rx = self.root(x)
        ry = self.root(y)
        if rx == ry:
            return False
        par = self.parent
        # A more negative entry means a larger set; keep that one as root.
        if par[ry] < par[rx]:
            rx, ry = ry, rx
        par[rx] += par[ry]
        par[ry] = rx
        return True


@dataclass(frozen=True)
class IsSortingOk:
    """Verdict that the checked network sorts every input; always truthy.

    Both verifiers in this file build ``value`` from their ``unused`` list:
    ``value[j]`` is True when comparator ``j`` was never found to change the
    values on its two wires.
    """

    # One flag per comparator of the checked network.
    value: list[bool]

    def __bool__(self) -> bool:
        # Truthiness carries the verdict, so callers can branch on ``if res:``.
        return True


@dataclass(frozen=True)
class IsSortingNg:
    """Verdict that the checked network fails to sort some input; falsy.

    The verifiers in this file fill ``value`` with ``n - 1`` flags:
    ``value[i]`` is True when some output can have wire ``i`` = 1 and
    wire ``i + 1`` = 0 (0-based wires).
    """

    # One flag per adjacent wire pair (i, i + 1).
    value: list[bool]

    def __bool__(self) -> bool:
        # Falsy so that ``if res:`` distinguishes failure from IsSortingOk.
        return False


@dataclass(frozen=True)
class Cmp:
    """Schedule step: apply comparators to one DSU component's states."""

    # DSU root of the component whose states the comparators act on.
    root: int
    # (comparator index in the network, wire a, wire b), applied in list order.
    cmp_part: list[tuple[int, int, int]]

@dataclass(frozen=True)
class Combine:
    """Schedule step: merge two DSU components into one."""

    # Root that survives the merge and receives the combined states.
    master: int
    # Root absorbed into ``master``; its state list is emptied.
    slave: int


def verify_strategy(n: int, net: list[tuple[int, int]]) -> "list[Cmp | Combine]":
    """Schedule the comparators of ``net`` into per-component steps.

    The result drives ``is_sorting_network_high``. Wires are grouped into
    DSU components, starting from singletons. Each pass scans the
    comparators not yet scheduled, in network order.

    A comparator is eligible only if neither of its wires was used by an
    earlier comparator visited in the same pass. This way it is never moved
    ahead of a pending comparator that shares a wire.

    Eligible comparators whose two wires already share a component are
    scheduled as ``Cmp`` steps, one per component root. If a pass schedules
    nothing, the code picks the eligible comparator joining two components
    with the smallest ``(combined size, root_a, root_b)``. Those components
    are merged and a ``Combine`` step is emitted; the comparator itself stays
    pending for a later pass.

    Args:
        n: number of wires.
        net: comparators as 0-based ``(a, b)`` wire pairs.

    Returns:
        The ordered list of ``Cmp`` and ``Combine`` steps.
    """
    # layered[i]: comparator i has already been placed in some Cmp step.
    layered = [False] * len(net)
    layers: list[Cmp | Combine] = []
    # Every comparator before index skip_len is layered; scans start there.
    skip_len = 0
    net_i = list(enumerate(net))
    dsu = Dsu(n)

    while skip_len < len(net):
        # Wires used by a comparator visited earlier in this pass.
        net_checked = [False] * n
        # Comparators scheduled in this pass, bucketed by component root.
        net_layer: list[list[tuple[int, int, int]]] = [[] for _ in range(n)]
        # Best merge candidate so far; a size of n + 1 means "none found".
        net_combine = (n + 1, 0, 0)

        for i, (a, b) in net_i[skip_len:]:
            if layered[i]:
                continue
            # Mark the wires before deciding, so later comparators on them
            # are blocked whether or not this one gets scheduled.
            checked = net_checked[a] or net_checked[b]
            net_checked[a] = True
            net_checked[b] = True
            if checked:
                continue
            if dsu.equiv(a, b):
                root_a = dsu.root(a)
                net_layer[root_a].append((i, a, b))
                layered[i] = True
            else:
                root_a = dsu.root(a)
                root_b = dsu.root(b)
                cand = (dsu.size(a) + dsu.size(b), root_a, root_b)
                if cand < net_combine:
                    net_combine = cand

        if any(net_layer):
            # Emit one Cmp per non-empty bucket, in ascending root order.
            for i, ces in enumerate(net_layer):
                if ces:
                    layers.append(Cmp(i, ces))
            while skip_len < len(net) and layered[skip_len]:
                skip_len += 1
        else:
            size, root_a, root_b = net_combine
            # No merge candidate was recorded either: stop scheduling.
            if size > n:
                break
            dsu.unite(root_a, root_b)
            root_master = dsu.root(root_a)
            # XOR yields whichever of the two old roots did not become the root.
            root_slave = root_a ^ root_b ^ root_master
            layers.append(Combine(root_master, root_slave))

    return layers


def is_sorting_network_high(n: int, net: list[tuple[int, int]]):
    """Verify ``net`` symbolically instead of enumerating all 2**n inputs.

    Each DSU component (scheduled by ``verify_strategy``) keeps a list of
    ``(z, o)`` pairs of wire bitmasks:

    - bit ``w`` of ``z`` set: wire ``w`` may be 0;
    - bit ``w`` of ``o`` set: wire ``w`` may be 1;
    - both bits set: wire ``w`` is still a free input bit.

    A pair stands for every assignment of its free wires. Comparators update
    pairs in place, and a pair is split only when both compared wires are
    free.

    Args:
        n: number of wires.
        net: comparators as 0-based ``(a, b)`` pairs. ``a`` receives the
            minimum and ``b`` the maximum.

    Returns:
        ``IsSortingOk(unused)`` when no reachable output is unsorted, where
        ``unused[j]`` is True if comparator ``j`` never changed anything.
        Otherwise ``IsSortingNg(flags)``, where ``flags[i]`` is True when
        wire ``i`` can be 1 while wire ``i + 1`` is 0.
    """
    # Initially every wire is its own component with a single all-free state.
    states = [[(1 << i, 1 << i)] for i in range(n)]
    unused = [True] * len(net)
    unsorted_i = 0
    # NOTE(review): this dsu is updated on Combine but never queried here.
    dsu = Dsu(n)

    for job in verify_strategy(n, net):
        if isinstance(job, Combine):
            master, slave = job.master, job.slave
            dsu.unite(master, slave)
            sm = states[master]
            ss = states[slave]
            # The two components cover different wires, so the merged states
            # are the Cartesian product with masks combined by OR.
            states[master] = [(sz | mz, so | mo) for sz, so in ss for mz, mo in sm]
            states[slave] = []
        else:
            root = job.root
            cmp_part = job.cmp_part
            root_states = states[root]

            # stack[k] holds states split off after comparator k - 1 that must
            # resume at cmp_part[k]; stack[-1] collects branch states that
            # have run through all of cmp_part.
            stack: list[list[tuple[int, int]]] = [[] for _ in range(len(cmp_part) + 1)]

            for idx, (z, o) in enumerate(root_states):
                for j, (cei, a, b) in enumerate(cmp_part):
                    # Wire a can never be 1 or wire b can never be 0: no exchange.
                    if ((o >> a) & (z >> b) & 1) == 0:
                        continue
                    if ((z >> a) & (o >> b) & 1) == 0:
                        # At most one of the two wires is free, so the result is
                        # determined: swap the (z, o) bits of wires a and b.
                        unused[cei] = False
                        xz = ((z >> a) ^ (z >> b)) & 1
                        xo = ((o >> a) ^ (o >> b)) & 1
                        z ^= (xz << a) | (xz << b)
                        o ^= (xo << a) | (xo << b)
                    else:
                        # Both wires free: the outputs are (0,0), (0,1), (1,1).
                        # Defer the branch with both wires forced to 0, and
                        # continue with b forced to 1 and a still free.
                        unused[cei] = False
                        stack[j + 1].append((z, o ^ (1 << a) ^ (1 << b)))
                        z ^= 1 << b
                root_states[idx] = (z, o)

            # stack[:-1] copies only the outer list; the inner lists are shared,
            # so branches pushed to later indices are still drained below.
            for i, st in enumerate(stack[:-1]):
                cmp_tail = cmp_part[i:]
                while st:
                    z, o = st.pop()
                    # j tracks the 1-based position in cmp_part of the
                    # comparator being applied, i.e. where a split resumes.
                    j = i
                    for cei, a, b in cmp_tail:
                        j += 1
                        if ((o >> a) & (z >> b) & 1) == 0:
                            continue
                        if ((z >> a) & (o >> b) & 1) == 0:
                            unused[cei] = False
                            xz = ((z >> a) ^ (z >> b)) & 1
                            xo = ((o >> a) ^ (o >> b)) & 1
                            z ^= (xz << a) | (xz << b)
                            o ^= (xo << a) | (xo << b)
                        else:
                            unused[cei] = False
                            stack[j].append((z, o ^ (1 << a) ^ (1 << b)))
                            z ^= 1 << b
                    stack[-1].append((z, o))

            root_states.extend(stack[-1])
            # Drop duplicate states before the next step.
            states[root] = list(set(root_states))

    # Keep bits 0 .. n-2; bit i stands for the adjacent pair (i, i + 1).
    n1_mask = (1 << (n - 1)) - 1
    for queue in states:
        # Wires covered by this component, read from its first state only.
        q_mask = (queue[0][0] | queue[0][1]) if queue else 0
        # Flag pair (i, i + 1) when wire i is in this component and
        # wire i + 1 is not.
        unsorted_i |= (q_mask & ((~q_mask) >> 1)) & n1_mask
        for z, o in queue:
            # Wire i may be 1 while wire i + 1 may be 0 in the same state.
            unsorted_i |= (o & (z >> 1))

    if unsorted_i:
        return IsSortingNg([((unsorted_i >> i) & 1) != 0 for i in range(n - 1)])
    return IsSortingOk(unused)


def is_sorting_network_low(n: int, net: list[tuple[int, int]]):
    """Verify ``net`` by brute force over all 0/1 inputs, bit-parallel.

    Each wire's value is one big integer whose bit ``k`` is the wire's value
    on input pattern ``k``. The first ``pbits`` wires enumerate all
    ``2**pbits`` patterns inside those integers. Any remaining wires are set
    to all-0 or all-1 by the outer loop, which runs ``2**(n - pbits)`` times.

    Args:
        n: number of wires.
        net: comparators as 0-based ``(a, b)`` pairs. ``a`` receives the
            minimum and ``b`` the maximum.

    Returns:
        ``IsSortingOk(unused)`` when every input ends sorted, where
        ``unused[j]`` is True if comparator ``j`` never changed a value.
        Otherwise ``IsSortingNg(unsorted)``, where ``unsorted[j]`` is True
        if some output has wire ``j`` = 1 and wire ``j + 1`` = 0.
    """
    m = len(net)
    unused = [True] * m
    unsorted = [False] * (n - 1)

    # Pack at most 15 wires, and never more than exist. Packing extra wires
    # only repeats the same inputs in wider integers: for small n that means
    # 2**15-bit operands where 2**n bits suffice.
    pbits = min(max(n, 0), 15)
    pfull = (1 << (1 << pbits)) - 1

    # lows[i] has bit k set exactly when bit i of k is set. Together these
    # give every 0/1 combination of the packed wires.
    lows: list[int] = []
    for i in range(pbits):
        le = ((1 << (1 << i)) - 1) << (1 << i)
        for j in range(i + 1, pbits):
            le |= le << (1 << j)
        lows.append(le)

    # Wires pbits .. n-1 (present only when n > 15) are constant per pass.
    for i in range(1 << max(n - pbits, 0)):
        p = lows + [(pfull if ((i >> j) & 1) else 0) for j in range(n - pbits)]
        for j, (a, b) in enumerate(net):
            na = p[a] & p[b]
            # p[a] changes only if some pattern has a = 1 and b = 0.
            if p[a] != na:
                p[a], p[b] = na, p[a] | p[b]
                unused[j] = False
        for j in range(n - 1):
            if (p[j] & ~p[j + 1]) != 0:
                unsorted[j] = True

    if any(unsorted):
        return IsSortingNg(unsorted)
    return IsSortingOk(unused)


def solve_one(n: int, m: int, a_line: list[int], b_line: list[int]) -> str:
    """Verify one comparator network and format the answer.

    Args:
        n: number of wires.
        m: number of comparators; the first ``m`` entries of ``a_line`` and
            ``b_line`` are used.
        a_line: 1-based first wire of each comparator.
        b_line: 1-based second wire of each comparator.

    Returns:
        Three lines: ``Yes`` or ``No``, then the count of flagged indices,
        then the 1-based flagged indices separated by spaces. For ``Yes`` the
        flags mark unused comparators; for ``No`` they mark adjacent wire
        pairs that can end out of order.
    """
    net = [(a_line[i] - 1, b_line[i] - 1) for i in range(m)]
    # Brute force up to 18 wires, symbolic search beyond. Presumably 18 keeps
    # the brute force at <= 2**3 passes over 2**15-bit integers — TODO confirm.
    res = is_sorting_network_low(n, net) if n <= 18 else is_sorting_network_high(n, net)

    # Both result types carry a flag list and format identically; only the
    # verdict word differs.
    verdict = "Yes" if res else "No"
    idx = [str(i + 1) for i, flag in enumerate(res.value) if flag]
    return "{}\n{}\n{}".format(verdict, len(idx), " ".join(idx))


def main():
    """Read every test case from stdin and write all answers to stdout.

    Reads a line with the case count ``T``. Each case then provides a line
    ``N M``, a line of ``A`` values and a line of ``B`` values.
    """
    # Named ``readline`` rather than shadowing the builtin ``input``.
    readline = sys.stdin.buffer.readline
    t = int(readline())

    answers = []
    for _ in range(t):
        n, m = map(int, readline().split())
        a_line = list(map(int, readline().split()))
        b_line = list(map(int, readline().split()))
        answers.append(solve_one(n, m, a_line, b_line))

    # One write for the whole output instead of one per test case.
    sys.stdout.write("".join(answer + "\n" for answer in answers))


# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
0