結果
| 問題 | No.3047 Verification of Sorting Network |
|---|---|
| ユーザー | 👑 |
| 提出日時 | 2025-03-13 04:23:21 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 1,168 ms / 2,000 ms |
| コード長 | 10,507 bytes |
| コンパイル時間 | 657 ms |
| コンパイル使用メモリ | 82,900 KB |
| 実行使用メモリ | 100,336 KB |
| 最終ジャッジ日時 | 2025-03-13 04:23:50 |
| 合計ジャッジ時間 | 27,501 ms |
| ジャッジサーバーID (参考情報) | judge1 / judge5 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 61 |
ソースコード
"""
yukicoder Problem: Verify Sorting Network
"""
import functools
import math
import sys
import time
import gc
# Wire count from which progress lines are written to stderr.
PROGRESS_THRESHOLD = 28
# State-list size at which gc.collect() is forced after a combine/apply step.
GC_THRESHOLD = 10**6
# When True, the input-limit assertions in main() always pass.
UNLIMITED = True
# Input limits checked by main() (effective only when UNLIMITED is False).
MAX_TESTCASES = 1000
MAX_N = 27
MAX_COST = 1.0e8
class Dsu:
    """Union-find over elements 0..n-1 with union by size and path compression.

    ``parent[x]`` holds x's parent, or ``-size`` of the group when x is a root.
    """

    def __init__(self, n: int):
        self.n = n
        self.parent = [-1] * n

    def root(self, x: int) -> int:
        """Return the representative of x's group, pointing the path at it."""
        parent = self.parent
        top = x
        while parent[top] >= 0:
            top = parent[top]
        # Second pass: every node on the walked path now points at the root
        while x != top:
            nxt = parent[x]
            parent[x] = top
            x = nxt
        return top

    def size(self, x: int) -> int:
        """Return how many elements share x's group."""
        return -self.parent[self.root(x)]

    def equiv(self, x: int, y: int) -> bool:
        """Return True when x and y are in the same group."""
        return self.root(x) == self.root(y)

    def unite(self, x: int, y: int) -> bool:
        """Merge the groups of x and y; return False if they already matched.

        The larger group's root survives; on a tie the root of x survives.
        """
        big, small = self.root(x), self.root(y)
        if big == small:
            return False
        # parent holds negative sizes, so a larger value means a smaller group
        if self.parent[big] > self.parent[small]:
            big, small = small, big
        self.parent[big] += self.parent[small]
        self.parent[small] = big
        return True
class IsSortingOk:
    """Positive verdict returned by is_sorting_network.

    Truthy, renders as ``Yes``, and keeps the flag list it was built with
    (is_sorting_network passes its per-comparator "unused" flags).
    """

    def __init__(self, value: list[bool]):
        # Flag list handed back unchanged by get_data()
        self.value = value

    def __str__(self):
        return "Yes"

    def __bool__(self):
        return True

    def get_data(self):
        """Return the flag list given at construction."""
        return self.value
class IsSortingNg:
    """Negative verdict returned by is_sorting_network.

    Falsy, renders as ``No``, and keeps the flag list it was built with
    (is_sorting_network passes one flag per adjacent wire position).
    """

    def __init__(self, value: list[bool]):
        # Flag list handed back unchanged by get_error()
        self.value = value

    def __str__(self):
        return "No"

    def __bool__(self):
        return False

    def get_error(self):
        """Return the flag list given at construction."""
        return self.value
class Cmp:
    """Job: a batch of comparators that act inside a single wire group.

    Holds the group's root wire and the comparators as ``(index, a, b)``
    triples, where ``index`` is the comparator's position in the network.
    """

    def __init__(self, root: int, cmp_part: list[tuple[int, int, int]]):
        self._root, self._cmp_part = root, cmp_part

    def root(self) -> int:
        """Return the root wire of the group."""
        return self._root

    def cmp_part(self) -> list[tuple[int, int, int]]:
        """Return the comparator triples of this batch."""
        return self._cmp_part
class Combine:
    """Job: merge the wire group rooted at ``slave`` into the one at ``master``."""

    def __init__(self, master: int, slave: int):
        self._master, self._slave = master, slave

    def master_slave(self) -> tuple[int, int]:
        """Return the ``(master, slave)`` root pair."""
        pair = (self._master, self._slave)
        return pair
def fib1(n: int) -> list[int]:
    """Generate Fibonacci sequence [1, 1, 2, 3, ..., Fib(n+1)].

    Returns n + 1 terms; for n <= 0 the result is just ``[1]``.
    Terms are appended in place, so the cost is O(n) instead of the O(n**2)
    of rebuilding the list with concatenation at every step.
    """
    seq = [1]
    prev, cur = 0, 1
    for _ in range(n):
        prev, cur = cur, prev + cur
        seq.append(cur)
    return seq
def net_layers(n: int, net: list[tuple[int, int]]) -> list[Cmp | Combine]:
    """Split network into layers.

    Returns an ordered list of jobs for is_sorting_network:
    - ``Cmp(root, [(i, a, b), ...])``: comparators (with their index ``i`` in
      ``net``) whose two wires already belong to the group rooted at ``root``.
    - ``Combine(master, slave)``: merge group ``slave`` into group ``master``;
      ``master`` is the root after the union.
    Groups are tracked with a Dsu in which every wire starts on its own.
    """
    # layered[i]: comparator i has already been emitted in some Cmp job
    layered: list[bool] = [False] * len(net)
    layers: list[Cmp | Combine] = []
    # Every comparator before index `skip` is already layered
    skip: int = 0
    net_i: list[tuple[int, tuple[int, int]]] = list(enumerate(net))
    dsu = Dsu(n)
    while skip < len(net):
        # net_checked[w]: wire w is touched by an earlier, still-unlayered
        # comparator in this pass, so later comparators on w must wait
        net_checked = [False] * n
        # Comparators taken in this pass, bucketed by group root
        net_layer = [[] for _ in range(n)]
        # Best merge candidate (combined size, root_a, root_b); size n + 1 = none
        net_combine = (n + 1, 0, 0)
        for i, (a, b) in net_i[skip:]:
            if layered[i]:
                continue
            checked = net_checked[a] or net_checked[b]
            net_checked[a] = net_checked[b] = True
            if checked:
                continue
            if dsu.equiv(a, b):
                # Both wires in one group: this comparator can be applied now
                root_a = dsu.root(a)
                net_layer[root_a].append((i, a, b))
                layered[i] = True
            else:
                # Crosses two groups: keep the smallest merge (ties by roots)
                root_a, root_b = dsu.root(a), dsu.root(b)
                size_a, size_b = dsu.size(a), dsu.size(b)
                net_combine = min(net_combine, (size_a + size_b, root_a, root_b))
        if all(len(x) == 0 for x in net_layer):
            # Nothing could be applied in this pass: merge the chosen groups
            size, root_a, root_b = net_combine
            if size > n:
                # No merge candidate was recorded
                break
            dsu.unite(root_a, root_b)
            root_master = dsu.root(root_a)
            # root_master is one of root_a/root_b; XOR yields the other one
            root_slave = root_a ^ root_b ^ root_master
            layers.append(Combine(root_master, root_slave))
        else:
            # Emit one Cmp job per group that received comparators
            for i, ces in enumerate(net_layer):
                if len(ces) == 0:
                    continue
                layers.append(Cmp(i, ces))
        # Advance skip past the longest prefix of layered comparators
        for i, f in enumerate(layered):
            if f:
                skip = i + 1
            else:
                break
    return layers
def is_sorting_network(n: int, net: list[tuple[int, int]]) -> IsSortingOk | IsSortingNg:
    """
    Check if the given network is a sorting network.

    Wires use three-valued logic: a state is a pair of bitmasks ``(z, o)``
    where bit i of ``z`` means "wire i may be 0" and bit i of ``o`` means
    "wire i may be 1" (both set = '?', indeterminate). Wires are grouped by
    ``net_layers``; each group keeps its own list of states, and two groups'
    lists are combined by cross product when a comparator joins them.

    Stated running time (original author): O(m * phi**n), where phi is the
    golden ratio 1.618...

    Args:
        n: number of wires (at least 2).
        net: comparators as 0-indexed pairs ``(a, b)`` with ``a < b``.

    Returns:
        ``IsSortingOk`` with one flag per comparator, True when that
        comparator never swapped or branched any explored state; otherwise
        ``IsSortingNg`` with one flag per adjacent pair (i, i+1), True when
        the pair was found possibly out of order.

    ref: wikipedia: Sorting network
        https://en.wikipedia.org/wiki/Sorting_network
    ref: Hisayasu Kuroda. (1997). A proposal of Gap Decrease Sorting Network.
        Trans.IPS.Japan, vol.38, no.3, p.381-389.
        http://id.nii.ac.jp/1001/00013442/
    ref: brianpursley/sorting-network, PR#9 three-valued-logic DFS approach
        https://github.com/brianpursley/sorting-network/pull/9
    """
    assert 2 <= n
    # Check the range of 0-indexed input
    assert all(0 <= a < b < n for a, b in net)
    # Number of comparators
    m = len(net)
    # Initial state: every wire is '?', each alone in its own group
    states = [[(1 << i, 1 << i)] for i in range(n)]
    # unused[c]: comparator c has not changed or branched any state so far
    unused = [True] * m
    # Bit i set: wires i and i+1 may end up out of order
    unsorted_i = 0
    # Replays the unions of net_layers; here it only feeds the progress sizes
    dsu = Dsu(n)
    for job in net_layers(n, net):
        if isinstance(job, Combine):
            master, slave = job.master_slave()
            size_master, size_slave = dsu.size(master), dsu.size(slave)
            dsu.unite(master, slave)
            len_master, len_slave = len(states[master]), len(states[slave])
            # Disjoint wire sets: pair every slave state with every master state
            states[master] = [(sz | mz, so | mo) for sz, so in states[slave] for mz, mo in states[master]]
            states[slave] = []
            if GC_THRESHOLD <= len_master * len_slave:
                gc.collect()
            if PROGRESS_THRESHOLD <= n:
                sys.stderr.write(f'Combining, size: {size_master}+{size_slave}=>{size_master + size_slave}, len: {len_master}*{len_slave}=>{len_master * len_slave}, root_master: {master}, root_slave: {slave}\n')
        elif isinstance(job, Cmp):
            root = job.root()
            cmp_part = job.cmp_part()
            num_cmp = len(cmp_part)
            len_pre = len(states[root])
            # Result states; the set removes duplicates produced by branching
            states_next = set()
            # DFS work list of (next comparator position, z, o)
            stack = []
            for z0, o0 in states[root]:
                stack.append((0, z0, o0))
                while stack:
                    i, z, o = stack.pop()
                    while i < num_cmp:
                        cei, a, b = cmp_part[i]
                        i += 1
                        if ((o >> a) & 1) == 0 or ((z >> b) & 1) == 0:
                            # a is surely 0 or b is surely 1: already ordered
                            continue
                        unused[cei] = False
                        if ((z >> a) & 1) == 0 or ((o >> b) & 1) == 0:
                            # a is surely 1 or b is surely 0: swap wires a and b
                            xz = ((z >> a) ^ (z >> b)) & 1
                            xo = ((o >> a) ^ (o >> b)) & 1
                            z ^= (xz << a) | (xz << b)
                            o ^= (xo << a) | (xo << b)
                        else:
                            # Both '?': defer the branch where both become 0...
                            stack.append((i, z, o ^ (1 << a) ^ (1 << b)))
                            # ...and continue with a = '?', b = 1
                            z ^= 1 << b
                    states_next.add((z, o))
            states[root] = list(states_next)
            len_dedup = len(states[root])
            if GC_THRESHOLD <= max(len_pre, len_dedup):
                gc.collect()
            if PROGRESS_THRESHOLD <= n:
                sys.stderr.write(f'AppliedCE, size: {dsu.size(root)}, len: {len_pre}=>{len_dedup}, root: {root}, cmp: {str(cmp_part)}\n')
    # Bits 0..n-2: the positions that have a right-hand neighbour
    n1_mask = (1 << (n - 1)) - 1
    for queue in states:
        if not queue:
            # Group was merged into another one
            continue
        # Wires of this group: every wire has its z or o bit set in any state
        q_mask = queue[0][0] | queue[0][1]
        # Wire i in this group but wire i+1 in another one: the groups never
        # interacted, so the position is treated as possibly unsorted
        unsorted_i |= (q_mask & (~q_mask >> 1)) & n1_mask
        for z, o in queue:
            # Wire i may be 1 while wire i+1 may be 0 in the same state
            unsorted_i |= (o & (z >> 1))
    # If there are unsorted branches
    if unsorted_i != 0:
        unsorted = [((unsorted_i >> i) & 1) != 0 for i in range(n - 1)]
        return IsSortingNg(unsorted)
    # If all branches are sorted
    return IsSortingOk(unused)
# Golden ratio phi = (1 + sqrt(5)) / 2 ≈ 1.618033988749895; main() uses
# m * PHI**n as the estimated cost of one test case.
PHI = (1.0 + math.sqrt(5.0)) / 2.0
def main():
    """Read every test case from stdin, judge it and print the answer.

    Per case: ``Yes``/``No``, then the count and the 1-indexed positions of
    the unused comparators (Yes) or of the possibly unsorted adjacent pairs
    (No). The total elapsed time is written to stderr.
    """
    t_begin = time.time()
    readline = sys.stdin.readline
    num_cases = int(readline())
    assert num_cases <= MAX_TESTCASES or UNLIMITED
    total_cost = 0
    for _ in range(num_cases):
        n, m = map(int, readline().split())
        assert 2 <= n <= MAX_N or UNLIMITED
        assert 1 <= m <= n * (n - 1) // 2 or UNLIMITED
        # Estimated work of this case, accumulated over all cases
        total_cost += m * PHI**n
        assert total_cost <= MAX_COST or UNLIMITED
        # Lazily convert the 1-indexed wire numbers to 0-indexed
        tops = (int(tok) - 1 for tok in readline().split())
        bottoms = (int(tok) - 1 for tok in readline().split())
        cmps: list[tuple[int, int]] = list(zip(tops, bottoms))
        assert len(cmps) == m
        assert all(0 <= lo < hi < n for lo, hi in cmps)
        verdict = is_sorting_network(n, cmps)
        print(verdict)  # Yes or No
        if verdict:
            # Unused-comparator flags, one per comparator
            flags = verdict.get_data()
            assert len(flags) == m
        else:
            # Possibly-unsorted flags, one per adjacent wire pair
            flags = verdict.get_error()
            assert len(flags) == n - 1
        picked = [pos + 1 for pos, flag in enumerate(flags) if flag]
        print(sum(flags))
        print(*picked)
    t_finish = time.time()
    sys.stderr.write(f'{t_finish - t_begin:.6f}[s]\n')
# Run only when executed as a script, so importing this module (e.g. to reuse
# is_sorting_network) does not block reading stdin.
if __name__ == '__main__':
    main()