結果

| 問題 | No.3047 Verification of Sorting Network |
|---|---|
| ユーザー | 👑 |
| 提出日時 | 2025-03-19 14:34:40 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 1,102 ms / 2,000 ms |
| コード長 | 10,819 bytes |
| コンパイル時間 | 371 ms |
| コンパイル使用メモリ | 82,104 KB |
| 実行使用メモリ | 96,908 KB |
| 最終ジャッジ日時 | 2025-03-19 14:35:06 |
| 合計ジャッジ時間 | 23,558 ms |
| ジャッジサーバーID (参考情報) | judge5 / judge1 (要ログイン) |

| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 61 |

ソースコード
""" yukicoder Problem: Verify Sorting Network """ import functools import math import sys import time PROGRESS_THRESHOLD = 28 GC_THRESHOLD = 1000000 UNLIMITED = True MAX_TESTCASES = 1000 MAX_N = 27 MAX_COST = 1e8 class Dsu: """Disjoint Set Union by size""" def __init__(self, n: int): self.n = n self.parent = [-1] * n def root(self, x: int) -> int: """Find the root of x""" if self.parent[x] < 0: return x self.parent[x] = self.root(self.parent[x]) return self.parent[x] def size(self, x: int) -> int: """size of the group""" return -self.parent[self.root(x)] def equiv(self, x: int, y: int) -> bool: """Check if x and y are in the same group""" return self.root(x) == self.root(y) def unite(self, x: int, y: int) -> bool: """Unite x and y""" x, y = self.root(x), self.root(y) if x == y: return False if -self.parent[x] < -self.parent[y]: x, y = y, x self.parent[x] += self.parent[y] self.parent[y] = x return True class IsSortingOk: """is_sorting_network Ok type""" def __init__(self, value: list[bool]): self.value = value def __bool__(self): return True def __str__(self): return 'Yes' def get_data(self): """get data value""" return self.value class IsSortingNg: """is_sorting_network Ng type""" def __init__(self, value: list[bool]): self.value = value def __bool__(self): return False def __str__(self): return 'No' def get_error(self): """get error value""" return self.value class Cmp: """Comparison part""" def __init__(self, root: int, cmp_part: list[tuple[int, int, int]]): self._data = (root, cmp_part) def data(self) -> tuple[int, list[tuple[int, int, int]]]: """root, cmp_part""" return self._data class Combine: """Combine master and slave""" def __init__(self, master: int, slave: int): self._pair = (master, slave) def master_slave(self) -> tuple[int, int]: """master and slave""" return self._pair def fib1(n: int) -> list[int]: """Generate Fibonacci sequence [1,1,2,3,…,Fib(n+1)].""" return functools.reduce(lambda x, _: x + [sum(x[-2:])], range(n), [1]) def verify_strategy(n: 
int, net: list[tuple[int, int]]) -> list[Cmp | Combine]: """Generate processing order""" layered: list[bool] = [False] * len(net) layers: list[Cmp | Combine] = [] skip_len: int = 0 net_i: list[tuple[int, tuple[int, int]]] = list(enumerate(net)) dsu = Dsu(n) while skip_len < len(net): net_checked = [False] * n net_layer = [[] for _ in range(n)] net_combine = (n + 1, 0, 0) for i, (a, b) in net_i[skip_len:]: if layered[i]: continue checked = net_checked[a] or net_checked[b] net_checked[a] = net_checked[b] = True if checked: continue if dsu.equiv(a, b): root_a = dsu.root(a) net_layer[root_a].append((i, a, b)) layered[i] = True else: root_a, root_b = dsu.root(a), dsu.root(b) size_a, size_b = dsu.size(a), dsu.size(b) net_combine = min(net_combine, (size_a + size_b, root_a, root_b)) if any(net_layer): # Append the comparison part for i, ces in enumerate(net_layer): if len(ces) == 0: continue layers.append(Cmp(i, ces)) for f in layered[skip_len:]: if f: skip_len += 1 else: break else: # Append the combine part size, root_a, root_b = net_combine if size > n: break dsu.unite(root_a, root_b) root_master = dsu.root(root_a) root_slave = root_a ^ root_b ^ root_master layers.append(Combine(root_master, root_slave)) return layers def is_sorting_network(n: int, net: list[tuple[int, int]]) -> IsSortingOk | IsSortingNg: """ Check if the given network is a sorting network. Runs in O(m * phi**n) time complexity. phi is the golden ratio 1.618... ref: wikipedia: Sorting network https://en.wikipedia.org/wiki/Sorting_network ref: Hisayasu Kuroda. (1997). A proposal of Gap Decrease Sorting Network. Trans.IPS.Japan, vol.38, no.3, p.381-389. 
http://id.nii.ac.jp/1001/00013442/ ref: brianpursley/sorting-network, PR#9 three-valued-logic DFS approach https://github.com/brianpursley/sorting-network/pull/9 ref: bertdobbelaere/SorterHunter, nw_tool.py https://github.com/bertdobbelaere/SorterHunter/blob/master/Util/nw_tool.py """ assert 2 <= n # Check the range of 0-indexed input assert all(0 <= a < b < n for a, b in net) # Number of comparators m = len(net) # Initial state is all '?' = indeterminate: not determined to be 0 or 1 states = [[(1 << i, 1 << i)] for i in range(n)] # Record whether the comparator is used unused = [True] * m # Record the position that is not sorted unsorted_i = 0 # Initialize DSU dsu = Dsu(n) # Execute search for each layer for job in verify_strategy(n, net): if isinstance(job, Combine): master, slave = job.master_slave() size_master, size_slave = dsu.size(master), dsu.size(slave) dsu.unite(master, slave) len_master, len_slave = len(states[master]), len(states[slave]) states[master], states[slave] = [(sz | mz, so | mo) for sz, so in states[slave] for mz, mo in states[master]], [] if PROGRESS_THRESHOLD <= n: sys.stderr.write(f'Combining, size: {size_master}+{size_slave}=>{size_master + size_slave}, len: {len_master}*{len_slave}=>{len_master * len_slave}, root_master: {master}, root_slave: {slave}\n') elif isinstance(job, Cmp): root, cmp_part = job.data() len_pre = len(states[root]) # Apply comparators, Generate branch stack: list[list[tuple[int, int]]] = [[] for _ in range(len(cmp_part) + 1)] for i, (z, o) in enumerate(states[root]): for j, (cei, a, b) in enumerate(cmp_part): if (1 & (o >> a) & (z >> b)) == 0: continue if (1 & (z >> a) & (o >> b)) == 0: unused[cei] = False xz, xo = (((z >> a) ^ (z >> b)) & 1), (((o >> a) ^ (o >> b)) & 1) z ^= (xz << a) | (xz << b) o ^= (xo << a) | (xo << b) else: unused[cei] = False stack[j + 1].append((z, o ^ (1 << a) ^ (1 << b))) z ^= 1 << b states[root][i] = (z, o) # Apply comparators, Branch processing for i, st in enumerate(stack): cmp_part_i = 
cmp_part[i:] if not cmp_part_i: break while st: z, o = st.pop() j = i for cei, a, b in cmp_part_i: j += 1 if (1 & (o >> a) & (z >> b)) == 0: continue if (1 & (z >> a) & (o >> b)) == 0: unused[cei] = False xz, xo = (((z >> a) ^ (z >> b)) & 1), (((o >> a) ^ (o >> b)) & 1) z ^= (xz << a) | (xz << b) o ^= (xo << a) | (xo << b) else: unused[cei] = False stack[j].append((z, o ^ (1 << a) ^ (1 << b))) z ^= 1 << b stack[-1].append((z, o)) # Write back the next state, Deduplication len_gen = len(states[root]) + len(stack[-1]) states[root] = list(set(states[root] + stack[-1])) del stack len_dedup = len(states[root]) if PROGRESS_THRESHOLD <= n: sys.stderr.write(f'AppliedCE, size: {dsu.size(root)}, len: {len_pre}=>{len_gen}=>{len_dedup}, root: {root}, cmp: {str(cmp_part)}\n') # Check if all branches are sorted for queue in states: n1_mask = (1 << (n - 1)) - 1 q_mask = (queue[0][0] | queue[0][1]) if queue else 0 unsorted_i |= (q_mask & (~q_mask >> 1)) & n1_mask for z, o in queue: unsorted_i |= (o & (z >> 1)) # If there are unsorted branches if unsorted_i != 0: unsorted = [((unsorted_i >> i) & 1) != 0 for i in range(n - 1)] return IsSortingNg(unsorted) # If all branches are sorted return IsSortingOk(unused) # Golden ratio (1+sqrt(5))/2 ≒ 1.618033988749895 PHI = math.sqrt(1.25) + 0.5 # Golden ratio def main(): """Input and output processing for test cases""" t = int(sys.stdin.readline()) assert t <= MAX_TESTCASES or UNLIMITED cost = 0 for _ in range(t): n, m = map(int, sys.stdin.readline().split()) assert 2 <= n <= MAX_N or UNLIMITED assert 1 <= m <= n * (n - 1) // 2 or UNLIMITED cost += m * PHI**n # Computational cost of test cases assert cost <= MAX_COST or UNLIMITED # 1-indexed -> 0-indexed a = list(map(lambda x: int(x) - 1, sys.stdin.readline().split())) b = list(map(lambda x: int(x) - 1, sys.stdin.readline().split())) cmps: list[tuple[int, int]] = list(zip(a, b)) assert len(cmps) == m assert all(0 <= a < b < n for a, b in cmps) # is_sorting: Check if the given network is a 
sorting network start = time.time() is_sorting = is_sorting_network(n, cmps) end = time.time() sys.stderr.write(f'{end - start:.6f}[s]\n') print(is_sorting) # Yes or No if is_sorting: # unused_cmp: Whether the comparator is unused (only if is_sorting=True) unused_cmp = is_sorting.get_data() assert len(unused_cmp) == m print(sum(unused_cmp)) print(*map(lambda e: e[0] + 1, filter(lambda e: e[1], enumerate(unused_cmp)))) else: # unsorted_pos: Positions that may not be sorted (only if is_sorting=False) unsorted_pos = is_sorting.get_error() assert len(unsorted_pos) == n - 1 print(sum(unsorted_pos)) print(*map(lambda e: e[0] + 1, filter(lambda e: e[1], enumerate(unsorted_pos)))) main()