結果
| 問題 |
No.3047 Verification of Sorting Network
|
| ユーザー |
👑 |
| 提出日時 | 2025-02-10 14:45:04 |
| 言語 | Python3 (3.13.1 + numpy 2.2.1 + scipy 1.14.1) |
| 結果 |
AC
|
| 実行時間 | 1,422 ms / 2,000 ms |
| コード長 | 4,015 bytes |
| コンパイル時間 | 382 ms |
| コンパイル使用メモリ | 12,416 KB |
| 実行使用メモリ | 11,136 KB |
| 最終ジャッジ日時 | 2025-03-05 20:39:40 |
| 合計ジャッジ時間 | 58,730 ms |
|
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 61 |
ソースコード
"""
yukicoder Problem: Verify Sorting Network
O(m*2^n)解法 + 32768bit並列化
"""
import math
import sys
MAX_TESTCASES = 10000
MAX_N = 27
MAX_COST = 1e8
STATE_ZERO = 0
STATE_ONE = 1
STATE_UNKNWON = 2
class IsSortingOk:
"""is_sorting_network Ok type"""
def __init__(self, value: list[bool]):
self.value = value
def __bool__(self):
return True
def __str__(self):
return 'Yes'
def get_data(self):
"""get data value"""
return self.value
class IsSortingNg:
"""is_sorting_network Ng type"""
def __init__(self, value: list[bool]):
self.value = value
def __bool__(self):
return False
def __str__(self):
return 'No'
def get_error(self):
"""get error value"""
return self.value
def is_sorting_network(n: int, net: list[tuple[int, int]]) -> IsSortingOk | IsSortingNg:
"""
与えられたネットワークが sorting network であるかどうかを調べます。
Args:
n: 入力の数
net: 比較器のリスト
"""
assert 2 <= n
# 0-indexed 入力の範囲を確認
assert all(0 <= a < b < n for a, b in net)
# 比較器の数
m = len(net)
unused = [True] * m
unsorted = [False] * (n - 1)
# (2**pbits)-bit の状態ベクトルを並列に処理する
pbits = 15
# (2**pbits)-bit の 1 が立っているビット列
pfull = (1 << (1 << pbits)) - 1
lows = []
for i in range(pbits):
le = ((1 << (1 << i)) - 1) << (1 << i)
for j in range(i + 1, pbits):
le |= (le << (1 << j))
lows.append(le)
# 2**n 通りの {0,1} 入力を 2**pbits 通りずつ処理する
for i in range(1 << max(n - pbits, 0)):
p = lows + [(pfull if ((i >> j) & 1) == 1 else 0) for j in range(n - pbits)]
for j, (a, b) in enumerate(net):
na = p[a] & p[b]
if p[a] != na:
p[a], p[b] = na, p[a] | p[b]
unused[j] = False
for j in range(n - 1):
if (p[j] & ~p[j + 1]) != 0:
unsorted[j] = True
# ソートされていない入力パターンがある場合
if any(unsorted):
return IsSortingNg(unsorted)
# すべての入力パターンでソートされている場合
return IsSortingOk(unused)
# 黄金比 (1+sqrt(5))/2 ≒ 1.618033988749895
PHI = math.sqrt(1.25) + 0.5 # 黄金比
def main():
"""テストケースの入出力処理"""
t = int(sys.stdin.readline())
assert t <= MAX_TESTCASES
cost = 0
for _ in range(t):
n, m = map(int, sys.stdin.readline().split())
assert 2 <= n <= MAX_N
assert 1 <= m <= n * (n - 1) // 2
cost += m * PHI**n # テストケースの計算量コスト
assert cost <= MAX_COST
# 1-indexed -> 0-indexed
a = map(lambda x: int(x) - 1, sys.stdin.readline().split())
b = map(lambda x: int(x) - 1, sys.stdin.readline().split())
cmps: list[tuple[int, int]] = list(zip(a, b))
assert len(cmps) == m
assert all(0 <= a < b < n for a, b in cmps)
# is_sorting: 与えられたネットワークが sorting network であるかどうか
is_sorting = is_sorting_network(n, cmps)
print(is_sorting) # Yes or No
if is_sorting:
# unused_cmp: 使われない比較器かどうか (is_sorting=True の場合のみ)
unused_cmp = is_sorting.get_data()
assert len(unused_cmp) == m
print(sum(unused_cmp))
print(*map(lambda e: e[0] + 1, filter(lambda e: e[1], enumerate(unused_cmp))))
else:
# unsorted_pos: ソートされない可能性のある位置 (is_sorting=False の場合のみ)
unsorted_pos = is_sorting.get_error()
assert len(unsorted_pos) == n - 1
print(sum(unsorted_pos))
print(*map(lambda e: e[0] + 1, filter(lambda e: e[1], enumerate(unsorted_pos))))
main()