結果
問題 |
No.3047 Verification of Sorting Network
|
ユーザー |
👑 |
提出日時 | 2025-02-10 14:45:04 |
言語 | Python3 (3.13.1 + numpy 2.2.1 + scipy 1.14.1) |
結果 |
AC
|
実行時間 | 1,422 ms / 2,000 ms |
コード長 | 4,015 bytes |
コンパイル時間 | 382 ms |
コンパイル使用メモリ | 12,416 KB |
実行使用メモリ | 11,136 KB |
最終ジャッジ日時 | 2025-03-05 20:39:40 |
合計ジャッジ時間 | 58,730 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 61 |
ソースコード
""" yukicoder Problem: Verify Sorting Network O(m*2^n)解法 + 32768bit並列化 """ import math import sys MAX_TESTCASES = 10000 MAX_N = 27 MAX_COST = 1e8 STATE_ZERO = 0 STATE_ONE = 1 STATE_UNKNWON = 2 class IsSortingOk: """is_sorting_network Ok type""" def __init__(self, value: list[bool]): self.value = value def __bool__(self): return True def __str__(self): return 'Yes' def get_data(self): """get data value""" return self.value class IsSortingNg: """is_sorting_network Ng type""" def __init__(self, value: list[bool]): self.value = value def __bool__(self): return False def __str__(self): return 'No' def get_error(self): """get error value""" return self.value def is_sorting_network(n: int, net: list[tuple[int, int]]) -> IsSortingOk | IsSortingNg: """ 与えられたネットワークが sorting network であるかどうかを調べます。 Args: n: 入力の数 net: 比較器のリスト """ assert 2 <= n # 0-indexed 入力の範囲を確認 assert all(0 <= a < b < n for a, b in net) # 比較器の数 m = len(net) unused = [True] * m unsorted = [False] * (n - 1) # (2**pbits)-bit の状態ベクトルを並列に処理する pbits = 15 # (2**pbits)-bit の 1 が立っているビット列 pfull = (1 << (1 << pbits)) - 1 lows = [] for i in range(pbits): le = ((1 << (1 << i)) - 1) << (1 << i) for j in range(i + 1, pbits): le |= (le << (1 << j)) lows.append(le) # 2**n 通りの {0,1} 入力を 2**pbits 通りずつ処理する for i in range(1 << max(n - pbits, 0)): p = lows + [(pfull if ((i >> j) & 1) == 1 else 0) for j in range(n - pbits)] for j, (a, b) in enumerate(net): na = p[a] & p[b] if p[a] != na: p[a], p[b] = na, p[a] | p[b] unused[j] = False for j in range(n - 1): if (p[j] & ~p[j + 1]) != 0: unsorted[j] = True # ソートされていない入力パターンがある場合 if any(unsorted): return IsSortingNg(unsorted) # すべての入力パターンでソートされている場合 return IsSortingOk(unused) # 黄金比 (1+sqrt(5))/2 ≒ 1.618033988749895 PHI = math.sqrt(1.25) + 0.5 # 黄金比 def main(): """テストケースの入出力処理""" t = int(sys.stdin.readline()) assert t <= MAX_TESTCASES cost = 0 for _ in range(t): n, m = map(int, sys.stdin.readline().split()) assert 2 <= n <= MAX_N assert 1 <= m <= n * (n - 1) // 2 cost += m * PHI**n # テストケースの計算量コスト assert cost <= MAX_COST # 1-indexed -> 0-indexed a = map(lambda x: int(x) - 1, sys.stdin.readline().split()) b = map(lambda x: int(x) - 1, sys.stdin.readline().split()) cmps: list[tuple[int, int]] = list(zip(a, b)) assert len(cmps) == m assert all(0 <= a < b < n for a, b in cmps) # is_sorting: 与えられたネットワークが sorting network であるかどうか is_sorting = is_sorting_network(n, cmps) print(is_sorting) # Yes or No if is_sorting: # unused_cmp: 使われない比較器かどうか (is_sorting=True の場合のみ) unused_cmp = is_sorting.get_data() assert len(unused_cmp) == m print(sum(unused_cmp)) print(*map(lambda e: e[0] + 1, filter(lambda e: e[1], enumerate(unused_cmp)))) else: # unsorted_pos: ソートされない可能性のある位置 (is_sorting=False の場合のみ) unsorted_pos = is_sorting.get_error() assert len(unsorted_pos) == n - 1 print(sum(unsorted_pos)) print(*map(lambda e: e[0] + 1, filter(lambda e: e[1], enumerate(unsorted_pos)))) main()