結果
問題 |
No.3047 Verification of Sorting Network
|
ユーザー |
👑 |
提出日時 | 2025-03-11 21:11:43 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 1,202 ms / 2,000 ms |
コード長 | 5,890 bytes |
コンパイル時間 | 425 ms |
コンパイル使用メモリ | 82,296 KB |
実行使用メモリ | 141,464 KB |
最終ジャッジ日時 | 2025-03-11 21:12:14 |
合計ジャッジ時間 | 30,161 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 61 |
ソースコード
""" yukicoder Problem: Verify Sorting Network """ import functools import sys import time SHOW_PROGRESS = True PROGRESS_THRESHOLD = 28 DICT_THRESHOLD = 20 class Dsu: """Disjoint Set Union by size""" def __init__(self, n: int): self.n = n self.parent = [-1] * n def root(self, x: int) -> int: """Find the root of x""" if self.parent[x] < 0: return x self.parent[x] = self.root(self.parent[x]) return self.parent[x] def unite(self, x: int, y: int) -> bool: """Unite x and y""" x, y = self.root(x), self.root(y) if x == y: return False if -self.parent[x] < -self.parent[y]: x, y = y, x self.parent[x] += self.parent[y] self.parent[y] = x return True def size(self, x: int) -> int: """Return the size of the group containing x""" return -self.parent[self.root(x)] class IsSortingOk: """is_sorting_network Ok type""" def __init__(self, value: list[bool]): self.value = value def __bool__(self): return True def __str__(self): return 'Yes' def get_data(self): """get data value""" return self.value class IsSortingNg: """is_sorting_network Ng type""" def __init__(self, value: list[bool]): self.value = value def __bool__(self): return False def __str__(self): return 'No' def get_error(self): """get error value""" return self.value def fib1(n: int) -> list[int]: """Generates the Fibonacci sequence [1,1,2,3,…,Fib(n+1)].""" return functools.reduce(lambda x, _: x + [sum(x[-2:])], range(n), [1]) def is_sorting_network(n: int, net: list[tuple[int, int]]) -> IsSortingOk | IsSortingNg: """ Checks if the given network is a sorting network. Operates in time complexity O(m * phi**n). phi is the golden ratio 1.618... """ assert 2 <= n # Check the range of 0-indexed inputs assert all(0 <= a < b < n for a, b in net) # Fibonacci sequence [1,1,2,3,…,Fib(n+1)] fib1n = fib1(n) # Number of comparators m = len(net) # Record whether the comparator is ever used unused_cmp = [] # Record unsorted positions unsorted_i = 0 # show_progress show_progress = SHOW_PROGRESS and n >= PROGRESS_THRESHOLD # Initial state is all '?' = indeterminate: not determined to be 0 or 1 dsu = Dsu(n) status: list[list[tuple[int, int]]] = [[(1 << i, 1 << i)] for i in range(n)] for i, (a, b) in enumerate(net): root_a, root_b = dsu.root(a), dsu.root(b) dsu.unite(a, b) root_master = dsu.root(a) root_slave = root_a ^ root_b ^ root_master set_par = set() unused_f = True if root_a != root_b: status[root_master] = [ (sz | mz, so | mo) for sz, so in status[root_slave] for mz, mo in status[root_master] ] status[root_slave] = [(0, 0)] for z, o in status[root_master]: if ((o >> a) & 1) == 0 or ((z >> b) & 1) == 0: set_par.add((z, o)) elif ((z >> a) & 1) == 0 or ((o >> b) & 1) == 0: unused_f = False xz, xo = (((z >> a) ^ (z >> b)) & 1), (((o >> a) ^ (o >> b)) & 1) z, o = (z ^ ((xz << a) | (xz << b))), (o ^ ((xo << a) | (xo << b))) set_par.add((z, o)) else: unused_f = False qz, qo, z = z, (o ^ (1 << a) ^ (1 << b)), (z ^ (1 << b)) set_par.add((qz, qo)) set_par.add((z, o)) unused_cmp.append(unused_f) status[root_master] = list(set_par) assert len(status[root_master]) <= fib1n[dsu.size(root_master)] if show_progress: percent = (i + 1) * 100 // m sys.stderr.write(f'{percent}%\r') for queue in status: n1_mask = (1 << (n - 1)) - 1 q_mask = (queue[0][0] | queue[0][1]) if queue else 0 unsorted_i |= ((q_mask & (~q_mask >> 1))) & n1_mask for z, o in queue: unsorted_i |= (o & (z >> 1)) if show_progress: sys.stderr.write('\n') # If there are unsorted branches if unsorted_i != 0: unsorted_pos = [((unsorted_i >> i) & 1) != 0 for i in range(n - 1)] return IsSortingNg(unsorted_pos) # If all branches are sorted return IsSortingOk(unused_cmp) def main(): """Input and output processing for test cases""" start = time.time() t = int(sys.stdin.readline()) for _ in range(t): n, m = map(int, sys.stdin.readline().split()) # 1-indexed -> 0-indexed a = map(lambda x: int(x) - 1, sys.stdin.readline().split()) b = map(lambda x: int(x) - 1, sys.stdin.readline().split()) cmps: list[tuple[int, int]] = list(zip(a, b)) assert len(cmps) == m assert all(0 <= a < b < n for a, b in cmps) # is_sorting: Check if the given network is a sorting network is_sorting = is_sorting_network(n, cmps) print(is_sorting) # Yes or No if is_sorting: # unused_cmp: Whether the comparator is unused (only if is_sorting=True) unused_cmp = is_sorting.get_data() assert len(unused_cmp) == m print(sum(unused_cmp)) print(*map(lambda e: e[0] + 1, filter(lambda e: e[1], enumerate(unused_cmp)))) else: # unsorted_pos: Positions that may not be sorted (only if is_sorting=False) unsorted_pos = is_sorting.get_error() assert len(unsorted_pos) == n - 1 print(sum(unsorted_pos)) print(*map(lambda e: e[0] + 1, filter(lambda e: e[1], enumerate(unsorted_pos)))) end = time.time() sys.stderr.write(f'{end - start:.6f}[s]\n') main()