結果
| 問題 | No.3047 Verification of Sorting Network |
|---|---|
| ユーザー | 👑 |
| 提出日時 | 2026-04-20 07:12:26 |
| 言語 | PyPy3 (7.3.17) |
| 結果 | AC |
| 実行時間 | 408 ms / 2,000 ms |
| コード長 | 7,517 bytes |
| 記録 | |
| コンパイル時間 | 179 ms |
| コンパイル使用メモリ | 84,992 KB |
| 実行使用メモリ | 108,364 KB |
| 最終ジャッジ日時 | 2026-04-20 07:12:42 |
| 合計ジャッジ時間 | 14,067 ms |
| ジャッジサーバーID (参考情報) | judge2_0 / judge1_0 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 61 |
ソースコード
#!/usr/bin/env pypy3
import sys
from dataclasses import dataclass
# Presumably the problem's upper bounds on the number of test cases and on
# the number of wires n (TODO confirm against the statement).
# NOTE(review): neither constant is referenced by any code in this file.
MAX_TESTCASES: int = 1_000
MAX_N: int = 27
class Dsu:
    """Disjoint-set union over the integers ``0 .. n-1``.

    ``parent[x]`` holds the parent of ``x`` or, when ``x`` is a root, the
    negated size of its set. Unions are by size, and ``root`` re-points
    every node it walks past to that node's grandparent (path splitting).
    """

    __slots__ = ("parent",)

    def __init__(self, n: int):
        # Every element starts as its own root of a size-1 set.
        self.parent = [-1 for _ in range(n)]

    def root(self, x: int) -> int:
        """Return the representative of the set containing ``x``."""
        par = self.parent
        cur = x
        while par[cur] >= 0:
            nxt = par[cur]
            grand = par[nxt]
            if grand >= 0:
                par[cur] = grand
            cur = nxt
        return cur

    def size(self, x: int) -> int:
        """Return how many elements share a set with ``x`` (including ``x``)."""
        rep = self.root(x)
        return -self.parent[rep]

    def equiv(self, x: int, y: int) -> bool:
        """Return True when ``x`` and ``y`` are in the same set."""
        rx = self.root(x)
        ry = self.root(y)
        return rx == ry

    def unite(self, x: int, y: int) -> bool:
        """Merge the sets of ``x`` and ``y``.

        Returns False (and changes nothing) if they were already joined.
        """
        par = self.parent
        rx = self.root(x)
        ry = self.root(y)
        if rx == ry:
            return False
        # Root entries are negated sizes, so "more negative" means larger;
        # the larger set's root survives, and on a tie rx is kept.
        if par[ry] < par[rx]:
            rx, ry = ry, rx
        par[rx] += par[ry]
        par[ry] = rx
        return True
@dataclass(frozen=True)
class IsSortingOk:
    """Positive verdict: the checked network sorts every input.

    ``value`` holds one flag per comparator; ``True`` marks a comparator
    that the checker never saw exchange its two wires.
    """

    value: list[bool]

    def __bool__(self) -> bool:
        # Truthy regardless of ``value`` so callers can branch on ``if res:``.
        return True
@dataclass(frozen=True)
class IsSortingNg:
    """Negative verdict: some input is left unsorted by the network.

    ``value`` holds one flag per adjacent wire pair ``(j, j + 1)``; ``True``
    marks a pair that can finish with wire ``j`` above wire ``j + 1``.
    """

    value: list[bool]

    def __bool__(self) -> bool:
        # Always falsy, mirroring IsSortingOk, so ``if res:`` selects the verdict.
        return False
@dataclass(frozen=True)
class Cmp:
    """Plan step: apply a batch of comparators inside one component.

    Attributes are declared in this order (positional construction relies on it):
    ``root`` is the DSU root wire naming the component, and ``cmp_part``
    lists ``(comparator_index, wire_a, wire_b)`` triples in network order.
    """

    root: int
    cmp_part: list[tuple[int, int, int]]
@dataclass(frozen=True)
class Combine:
    """Plan step: fold the component rooted at ``slave`` into ``master``.

    ``master`` is the DSU root that survives the union; ``slave`` is the
    other former root.
    """

    master: int
    slave: int
def verify_strategy(n: int, net: list[tuple[int, int]]):
    """Plan an evaluation order for ``net`` as a list of ``Cmp``/``Combine`` jobs.

    Wires start as singleton components of a DSU. Each pass scans the
    comparators not yet scheduled, in network order, and considers a
    comparator "ready" when no earlier unscheduled comparator seen in the
    same pass touched either of its wires. Ready comparators whose two wires
    already share a component are scheduled, grouped per component root
    into ``Cmp`` jobs (within one pass they touch pairwise distinct wires).
    If a pass schedules nothing, the two components of the ready
    cross-component comparator with the smallest combined size (ties broken
    by the root indices) are merged, a ``Combine`` job is emitted, and the
    scan repeats.

    Args:
        n: number of wires; wire indices in ``net`` are 0-based.
        net: comparators as ``(a, b)`` wire pairs, in application order.

    Returns:
        The jobs in the order they must be applied. A ``Cmp`` job's ``root``
        and a ``Combine`` job's ``master`` are DSU roots at emission time.
    """
    layered = [False] * len(net)  # layered[i]: comparator i already scheduled
    layers: list[Cmp | Combine] = []
    # Every comparator before index skip_len is already scheduled.
    skip_len = 0
    net_i = list(enumerate(net))
    dsu = Dsu(n)
    while skip_len < len(net):
        # net_checked[w]: some comparator seen earlier in this pass touches w.
        net_checked = [False] * n
        net_layer: list[list[tuple[int, int, int]]] = [[] for _ in range(n)]
        # Best merge candidate as (combined size, root_a, root_b); n + 1 is
        # larger than any real size and means "no candidate".
        net_combine = (n + 1, 0, 0)
        for i, (a, b) in net_i[skip_len:]:
            if layered[i]:
                continue
            checked = net_checked[a] or net_checked[b]
            net_checked[a] = True
            net_checked[b] = True
            if checked:
                # An earlier comparator on one of these wires is still
                # pending (or was taken in this pass), so this one must wait.
                continue
            if dsu.equiv(a, b):
                root_a = dsu.root(a)
                net_layer[root_a].append((i, a, b))
                layered[i] = True
            else:
                root_a = dsu.root(a)
                root_b = dsu.root(b)
                cand = (dsu.size(a) + dsu.size(b), root_a, root_b)
                if cand < net_combine:
                    net_combine = cand
        if any(net_layer):
            for i, ces in enumerate(net_layer):
                if ces:
                    layers.append(Cmp(i, ces))
            while skip_len < len(net) and layered[skip_len]:
                skip_len += 1
        else:
            size, root_a, root_b = net_combine
            if size > n:
                # Defensive: no candidate was recorded. The first unscheduled
                # comparator appears to be always ready, so this looks
                # unreachable.
                break
            dsu.unite(root_a, root_b)
            root_master = dsu.root(root_a)
            # root_master is one of root_a / root_b; XOR recovers the other.
            root_slave = root_a ^ root_b ^ root_master
            layers.append(Combine(root_master, root_slave))
    return layers
def is_sorting_network_high(n: int, net: list[tuple[int, int]]):
    """Check ``net`` on all 0/1 inputs by symbolic, component-wise simulation.

    The job order comes from ``verify_strategy``. Each component root owns a
    list of states ``(z, o)``: bit ``w`` of ``z`` set means wire ``w`` may be
    0, and bit ``w`` of ``o`` set means it may be 1. A wire with both bits
    set is an unconstrained input bit, and a state is treated as standing for
    every assignment of its unconstrained wires.

    * ``Combine``: the slave's states are crossed with the master's.
    * ``Cmp``: each comparator ``(a, b)`` is applied to each state:
      - if wire a can never be 1 while b is 0, nothing changes;
      - if a is fixed at 1 or b is fixed at 0, the two wires' ``z``/``o``
        bits are exchanged;
      - if both are unconstrained, the state splits into "a = b = 0"
        (continued from the next comparator) and "a unconstrained, b = 1".

    Returns:
        ``IsSortingNg(flags)`` with ``n - 1`` flags, where flag ``i`` means
        wire ``i`` may end as 1 while wire ``i + 1`` ends as 0 (or the two
        wires end in different components). Otherwise ``IsSortingOk(unused)``,
        where ``unused[j]`` is True iff comparator ``j`` never changed a state.
    """
    # One state list per wire; wire i starts alone and unconstrained.
    states = [[(1 << i, 1 << i)] for i in range(n)]
    unused = [True] * len(net)
    unsorted_i = 0  # bit i set: adjacent pair (i, i + 1) can end unsorted
    # Replays the plan's unions; this DSU is not queried anywhere below.
    dsu = Dsu(n)
    for job in verify_strategy(n, net):
        if isinstance(job, Combine):
            master, slave = job.master, job.slave
            dsu.unite(master, slave)
            sm = states[master]
            ss = states[slave]
            # Cartesian product; the two wire sets are disjoint, so OR-ing
            # the masks concatenates the two partial states.
            states[master] = [(sz | mz, so | mo) for sz, so in ss for mz, mo in sm]
            states[slave] = []
        else:
            root = job.root
            cmp_part = job.cmp_part
            root_states = states[root]
            # stack[k]: split-off states that still need cmp_part[k:];
            # stack[-1] collects split-off states that ran to the end.
            stack: list[list[tuple[int, int]]] = [[] for _ in range(len(cmp_part) + 1)]
            for idx, (z, o) in enumerate(root_states):
                for j, (cei, a, b) in enumerate(cmp_part):
                    # No assignment has a = 1 and b = 0: the comparator is a no-op.
                    if ((o >> a) & (z >> b) & 1) == 0:
                        continue
                    if ((z >> a) & (o >> b) & 1) == 0:
                        # a is fixed at 1 or b at 0: exchanging the wires'
                        # z/o bits yields exactly the compared outcome.
                        unused[cei] = False
                        xz = ((z >> a) ^ (z >> b)) & 1
                        xo = ((o >> a) ^ (o >> b)) & 1
                        z ^= (xz << a) | (xz << b)
                        o ^= (xo << a) | (xo << b)
                    else:
                        # Both unconstrained: queue the "a = b = 0" branch to
                        # resume at the next comparator; keep "b = 1" here.
                        unused[cei] = False
                        stack[j + 1].append((z, o ^ (1 << a) ^ (1 << b)))
                        z ^= 1 << b
                root_states[idx] = (z, o)
            # stack[:-1] copies only the outer list; the inner lists are
            # shared, so states pushed into later buckets are still drained.
            for i, st in enumerate(stack[:-1]):
                cmp_tail = cmp_part[i:]
                while st:
                    z, o = st.pop()
                    # j ends up one past the index (in cmp_part) of the
                    # comparator being processed, i.e. the resume point.
                    j = i
                    for cei, a, b in cmp_tail:
                        j += 1
                        if ((o >> a) & (z >> b) & 1) == 0:
                            continue
                        if ((z >> a) & (o >> b) & 1) == 0:
                            unused[cei] = False
                            xz = ((z >> a) ^ (z >> b)) & 1
                            xo = ((o >> a) ^ (o >> b)) & 1
                            z ^= (xz << a) | (xz << b)
                            o ^= (xo << a) | (xo << b)
                        else:
                            unused[cei] = False
                            stack[j].append((z, o ^ (1 << a) ^ (1 << b)))
                            z ^= 1 << b
                    stack[-1].append((z, o))
            root_states.extend(stack[-1])
            # Drop duplicate states (list order is not preserved).
            states[root] = list(set(root_states))
    n1_mask = (1 << (n - 1)) - 1
    for queue in states:
        # Wires owned by this component, read from its first state.
        q_mask = (queue[0][0] | queue[0][1]) if queue else 0
        # Flag pair (i, i + 1) when wire i is in this component and wire
        # i + 1 is not.
        unsorted_i |= (q_mask & ((~q_mask) >> 1)) & n1_mask
        for z, o in queue:
            # Within one state: wire i may be 1 while wire i + 1 may be 0.
            unsorted_i |= (o & (z >> 1))
    if unsorted_i:
        return IsSortingNg([((unsorted_i >> i) & 1) != 0 for i in range(n - 1)])
    return IsSortingOk(unused)
def is_sorting_network_low(n: int, net: list[tuple[int, int]]):
    """Check ``net`` by simulating it on every one of the 2**n 0/1 inputs.

    By the 0-1 principle, sorting every 0/1 input is equivalent to sorting
    every input. Inputs are packed into big-integer bitsets: bit ``k`` of
    ``p[w]`` is the value on wire ``w`` for the ``k``-th input of the
    current batch. The lowest ``pbits`` wires enumerate all of their
    combinations inside one integer. Any higher wires are constant within a
    batch, and the outer loop iterates over their combinations.

    Args:
        n: number of wires (exponential cost, so intended for small ``n``).
        net: comparators ``(a, b)``, 0-based; the minimum is sent to wire
            ``a`` and the maximum to wire ``b``.

    Returns:
        ``IsSortingOk(unused)`` when every input ends sorted, where
        ``unused[j]`` is True iff comparator ``j`` never exchanged its wires.
        Otherwise ``IsSortingNg(unsorted)``, where ``unsorted[j]`` is True iff
        some input ends with wire ``j`` at 1 and wire ``j + 1`` at 0.
    """
    m = len(net)
    unused = [True] * m
    unsorted = [False] * (n - 1)
    # Pack at most 2**15 inputs per integer, but never more than 2**n: for
    # small n this keeps every bitset at 2**n bits instead of 32768 bits.
    # Each input combination is still visited exactly once, so results match.
    pbits = min(n, 15)
    pfull = (1 << (1 << pbits)) - 1
    lows: list[int] = []
    for i in range(pbits):
        # Build the mask whose bit k equals bit i of k, for 0 <= k < 2**pbits.
        le = ((1 << (1 << i)) - 1) << (1 << i)
        for j in range(i + 1, pbits):
            le |= le << (1 << j)
        lows.append(le)
    for i in range(1 << max(n - pbits, 0)):
        # Wires pbits.. are all-ones or all-zeros in this batch, according
        # to the bits of the batch number i.
        p = lows + [(pfull if ((i >> j) & 1) else 0) for j in range(n - pbits)]
        for j, (a, b) in enumerate(net):
            na = p[a] & p[b]  # bitwise AND is the per-input minimum
            if p[a] != na:
                # Some input has a = 1 and b = 0, so this comparator swaps.
                p[a], p[b] = na, p[a] | p[b]
                unused[j] = False
        for j in range(n - 1):
            if (p[j] & ~p[j + 1]) != 0:
                unsorted[j] = True
    if any(unsorted):
        return IsSortingNg(unsorted)
    return IsSortingOk(unused)
def solve_one(n: int, m: int, a_line: list[int], b_line: list[int]) -> str:
    """Check one network and format the three-line answer.

    Args:
        n: number of wires.
        m: number of comparators; only the first ``m`` entries of
            ``a_line``/``b_line`` are used.
        a_line, b_line: 1-based wire indices of each comparator's two ends.

    Returns:
        ``"Yes"``/``"No"``, then the number of flagged entries in the
        checker's ``value``, then their 1-based positions separated by spaces.
        For "Yes" these are the comparators that never swap. For "No" they
        are the adjacent wire pairs that can end out of order.
    """
    net = [(a_line[i] - 1, b_line[i] - 1) for i in range(m)]
    # Exhaustive 2**n bitset simulation for small n, the component-wise
    # symbolic search otherwise.
    checker = is_sorting_network_low if n <= 18 else is_sorting_network_high
    res = checker(n, net)
    # Both result types expose .value; only the verdict word differs.
    verdict = "Yes" if res else "No"
    idx = [str(i + 1) for i, flag in enumerate(res.value) if flag]
    return "{}\n{}\n{}".format(verdict, len(idx), " ".join(idx))
def main():
    """Read every test case from stdin and write all answers to stdout.

    Expected input: ``t``, then for each case ``n m``, the ``m`` values of
    ``a_line`` and the ``m`` values of ``b_line``. The whole input is
    tokenized at once, so parsing does not depend on line layout (e.g. a
    case with ``m == 0`` whose comparator lines are empty or absent). The
    answers are joined into a single write instead of one write per case.
    """
    data = sys.stdin.buffer.read().split()
    t = int(data[0])
    pos = 1
    answers = []
    for _ in range(t):
        n, m = int(data[pos]), int(data[pos + 1])
        pos += 2
        a_line = [int(tok) for tok in data[pos:pos + m]]
        pos += m
        b_line = [int(tok) for tok in data[pos:pos + m]]
        pos += m
        answers.append(solve_one(n, m, a_line, b_line))
    if answers:
        sys.stdout.write("\n".join(answers) + "\n")


if __name__ == "__main__":
    main()