結果
| 問題 |
No.3047 Verification of Sorting Network
|
| ユーザー |
👑 |
| 提出日時 | 2025-03-11 21:07:31 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 6,596 bytes |
| コンパイル時間 | 568 ms |
| コンパイル使用メモリ | 82,848 KB |
| 実行使用メモリ | 145,412 KB |
| 最終ジャッジ日時 | 2025-03-11 21:08:02 |
| 合計ジャッジ時間 | 30,372 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | WA * 3 |
| other | AC * 12 WA * 49 |
ソースコード
"""
yukicoder Problem: Verify Sorting Network
"""
import functools
import sys
import time
# Master switch for the percentage progress meter that is_sorting_network
# writes to stderr.
SHOW_PROGRESS = True
# The progress meter is only enabled when the wire count n is at least this.
PROGRESS_THRESHOLD = 28
# NOTE(review): not referenced anywhere in the visible code - confirm it is
# unused before removing.
DICT_THRESHOLD = 20
class Dsu:
    """Union-find over the indices 0..n-1 (union by size, path compression).

    ``parent[v]`` stores the parent index of ``v``; for a root it stores the
    negated size of that root's group instead.
    """

    def __init__(self, n: int):
        self.n = n
        self.parent = [-1 for _ in range(n)]

    def root(self, x: int) -> int:
        """Return the representative of x's group, flattening the path to it."""
        links = self.parent
        # First pass: walk up to the representative.
        top = x
        while links[top] >= 0:
            top = links[top]
        # Second pass: point every node on the walked path straight at it.
        while x != top:
            above = links[x]
            links[x] = top
            x = above
        return top

    def unite(self, x: int, y: int) -> bool:
        """Merge the groups of x and y; return False if they already coincide."""
        keep = self.root(x)
        drop = self.root(y)
        if keep == drop:
            return False
        links = self.parent
        # Sizes are stored negated, so the larger group has the smaller entry.
        # On a tie the root of the first argument survives.
        if links[keep] > links[drop]:
            keep, drop = drop, keep
        links[keep] += links[drop]
        links[drop] = keep
        return True

    def size(self, x: int) -> int:
        """Return how many elements share a group with x."""
        return -self.parent[self.root(x)]
class IsSortingOk:
    """Truthy verdict object that prints as 'Yes' and carries a payload list."""

    def __init__(self, value: list[bool]):
        self.value = value

    def get_data(self) -> list[bool]:
        """Hand back the payload supplied at construction."""
        return self.value

    def __str__(self) -> str:
        return "Yes"

    def __bool__(self) -> bool:
        return True
class IsSortingNg:
    """Falsy verdict object that prints as 'No' and carries a payload list."""

    def __init__(self, value: list[bool]):
        self.value = value

    def get_error(self) -> list[bool]:
        """Hand back the payload supplied at construction."""
        return self.value

    def __str__(self) -> str:
        return "No"

    def __bool__(self) -> bool:
        return False
def fib1(n: int) -> list[int]:
    """Return the n + 1 leading Fibonacci numbers 1, 1, 2, 3, 5, ...

    For n <= 0 the result is just ``[1]``.
    """
    seq = [1]
    older, newer = 0, 1
    for _ in range(n):
        older, newer = newer, older + newer
        seq.append(newer)
    return seq
def is_sorting_network(n: int, net: list[tuple[int, int]]) -> IsSortingOk | IsSortingNg:
    """Decide whether a comparator network on n wires sorts every 0/1 input.

    The reachable 0/1 vectors are tracked as a set of states per connected
    group of wires. A state is a pair of bitmasks ``(z, o)``: bit i of ``z``
    means wire i may carry 0, bit i of ``o`` means wire i may carry 1; both
    set means the wire is still undetermined ('?').

    Per the original author's note the running time is O(m * phi**n), with
    phi the golden ratio.

    Args:
        n: number of wires, at least 2.
        net: comparators as 0-indexed pairs ``(a, b)`` with ``a < b``; each
            moves the smaller value to wire a and the larger to wire b.

    Returns:
        ``IsSortingOk`` holding one bool per comparator (True = the comparator
        never exchanges its inputs) when every adjacent pair of outputs is
        always ordered; otherwise ``IsSortingNg`` holding n - 1 bools, True at
        index i when outputs i and i + 1 can end up out of order.

    Raises:
        AssertionError: if n < 2, a comparator is out of range, or the
            per-group state count exceeds its Fibonacci bound.
    """
    assert 2 <= n
    # Check the range of 0-indexed inputs
    assert all(0 <= a < b < n for a, b in net)
    # Fibonacci sequence [1,1,2,3,...,Fib(n+1)], used as a state-count bound
    fib1n = fib1(n)
    # Number of comparators
    m = len(net)
    # Record, per comparator, whether it never exchanges anything
    unused_cmp = []
    # Bit i set <=> outputs i and i+1 may be out of order
    unsorted_i = 0
    show_progress = SHOW_PROGRESS and n >= PROGRESS_THRESHOLD
    # Initially every wire is its own group in the single state '?'
    dsu = Dsu(n)
    status: list[list[tuple[int, int]]] = [[(1 << i, 1 << i)] for i in range(n)]
    for i, (a, b) in enumerate(net):
        root_a, root_b = dsu.root(a), dsu.root(b)
        dsu.unite(a, b)
        root_master = dsu.root(a)
        if root_a != root_b:
            # The groups were independent: the joint states are the cross
            # product of both state lists. The absorbed root keeps a neutral
            # placeholder so it contributes nothing to the final scan.
            root_slave = root_a ^ root_b ^ root_master
            slave_states = status[root_slave]
            master_states = status[root_master]
            status[root_slave] = [(0, 0)]
            candidates = (
                (sz | mz, so | mo)
                for sz, so in slave_states
                for mz, mo in master_states
            )
        else:
            candidates = status[root_master]
        bit_a, bit_b = 1 << a, 1 << b
        both = bit_a | bit_b
        set_par = set()
        unused_f = True
        for z, o in candidates:
            if not (o & bit_a) or not (z & bit_b):
                # a is surely 0 or b is surely 1: already ordered, no exchange
                set_par.add((z, o))
                continue
            unused_f = False
            if not (z & bit_a) or not (o & bit_b):
                # a is surely 1 or b is surely 0: the two wires trade places,
                # i.e. bits a and b are swapped in both masks
                if bool(z & bit_a) != bool(z & bit_b):
                    z ^= both
                if bool(o & bit_a) != bool(o & bit_b):
                    o ^= both
                set_par.add((z, o))
            else:
                # both '?': the outputs are either (a=0, b=0) or (a=?, b=1)
                set_par.add((z, o ^ both))
                set_par.add((z ^ bit_b, o))
        unused_cmp.append(unused_f)
        status[root_master] = list(set_par)
        assert len(status[root_master]) <= fib1n[dsu.size(root_master)]
        if show_progress:
            percent = (i + 1) * 100 // m
            sys.stderr.write(f'{percent}%\r')
    # One bit per adjacent pair (i, i+1), i.e. bits 0..n-2. The previous mask
    # 1 << (n - 1) selected only bit n-1, which is set for whichever group
    # holds the last wire, so every network was reported as not sorting.
    pair_mask = (1 << (n - 1)) - 1
    for queue in status:
        # z | o is the membership mask of the group (identical for all states)
        q_mask = (queue[0][0] | queue[0][1]) if queue else 0
        # wire i in this group but wire i+1 in another: the two are
        # independent, so the pair can be out of order
        unsorted_i |= (q_mask & (~q_mask >> 1)) & pair_mask
        for z, o in queue:
            # wire i may be 1 while wire i+1 may be 0 within one state
            unsorted_i |= o & (z >> 1)
    if show_progress:
        sys.stderr.write('\n')
    # If there are unsorted branches
    if unsorted_i != 0:
        unsorted_pos = [((unsorted_i >> i) & 1) != 0 for i in range(n - 1)]
        return IsSortingNg(unsorted_pos)
    # If all branches are sorted
    return IsSortingOk(unused_cmp)
def main():
    """Read all test cases from stdin and print the verdict block for each.

    Every case prints three lines: 'Yes' or 'No', the number of flagged
    entries, and their 1-indexed positions (unused comparators on 'Yes',
    possibly unsorted adjacent positions on 'No'). The elapsed wall time is
    written to stderr at the end.
    """
    began = time.time()
    read_line = sys.stdin.readline
    case_count = int(read_line())
    for _ in range(case_count):
        n, m = (int(tok) for tok in read_line().split())
        # Endpoints arrive 1-indexed; shift them to 0-indexed while pairing
        lows = (int(tok) - 1 for tok in read_line().split())
        highs = (int(tok) - 1 for tok in read_line().split())
        cmps: list[tuple[int, int]] = list(zip(lows, highs))
        assert len(cmps) == m
        assert all(0 <= lo < hi < n for lo, hi in cmps)
        is_sorting = is_sorting_network(n, cmps)
        print(is_sorting)  # Yes or No
        if is_sorting:
            # One flag per comparator: True when it is never needed
            flags = is_sorting.get_data()
            assert len(flags) == m
        else:
            # One flag per adjacent output pair: True when it may be unsorted
            flags = is_sorting.get_error()
            assert len(flags) == n - 1
        print(sum(flags))
        print(*[pos + 1 for pos, hit in enumerate(flags) if hit])
    finished = time.time()
    sys.stderr.write(f'{finished - began:.6f}[s]\n')
# Run the solver only when executed as a script, so that importing this
# module (e.g. to reuse is_sorting_network) does not block on stdin.
if __name__ == "__main__":
    main()