結果
問題 |
No.3047 Verification of Sorting Network
|
ユーザー |
👑 |
提出日時 | 2025-03-14 19:36:09 |
言語 | Rust (1.83.0 + proconio) |
結果 |
AC
|
実行時間 | 83 ms / 2,000 ms |
コード長 | 18,887 bytes |
コンパイル時間 | 15,074 ms |
コンパイル使用メモリ | 397,712 KB |
実行使用メモリ | 7,324 KB |
最終ジャッジ日時 | 2025-03-14 19:36:31 |
合計ジャッジ時間 | 19,177 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 61 |
ソースコード
/// Minimum number of wires `n` at which progress diagnostics are written to stderr.
const PROGRESS_THRESHOLD: usize = 28;
/// Maximum number of test cases accepted by `main`.
const MAX_T: usize = 1000;
/// Maximum number of wires accepted by `main` and by the checkers.
const MAX_N: usize = 64;
/// Upper bound asserted on the accumulated cost estimate `sum(m * phi^n)` in `main`.
const MAX_COST: f64 = 1e17;

/// Bit set over the wires of a network: bit `i` refers to wire `i`.
type State = u64;

/// Fibonacci numbers: `FIB1[0] = 1`, `FIB1[1] = 1`,
/// `FIB1[i] = FIB1[i-1] + FIB1[i-2]` for `2 <= i <= State::BITS`.
const FIB1: [State; (State::BITS + 1) as usize] = {
    let mut fib = [1; (State::BITS + 1) as usize];
    let mut i = 2;
    while i <= State::BITS as usize {
        fib[i] = fib[i - 1] + fib[i - 2];
        i += 1;
    }
    fib
};

/// One slot of the union-find forest.
#[derive(Debug, Clone, Copy)]
enum DsuBySizeElement {
    /// The slot is a root; the payload is the number of elements in its set.
    Size(usize),
    /// The slot is not a root; the payload is an ancestor on the way to the root.
    Parent(usize),
}

/// Disjoint Set Union (Union-Find) with union by size and path compression.
#[derive(Debug, Clone)]
pub struct DsuBySize(Vec<DsuBySizeElement>);

impl DsuBySize {
    /// Creates `n` singleton sets `{0}, {1}, ..., {n-1}`.
    pub fn new(n: usize) -> Self {
        Self(vec![DsuBySizeElement::Size(1); n])
    }

    /// Returns `(root, size)` of the set containing `u`, compressing the path to the root.
    pub fn root_size(&mut self, u: usize) -> (usize, usize) {
        match self.0[u] {
            DsuBySizeElement::Size(size) => (u, size),
            // A slot that names itself as parent is treated as a singleton root.
            DsuBySizeElement::Parent(v) if u == v => (u, 1),
            DsuBySizeElement::Parent(v) => {
                let (root, size) = self.root_size(v);
                self.0[u] = DsuBySizeElement::Parent(root);
                (root, size)
            }
        }
    }

    /// Merges the sets containing `u` and `v`.
    ///
    /// Returns `false` when they were already the same set. On a size tie the
    /// root of `u` stays the root of the merged set.
    pub fn unite(&mut self, u: usize, v: usize) -> bool {
        let (u, size_u) = self.root_size(u);
        let (v, size_v) = self.root_size(v);
        if u == v {
            return false;
        }
        if size_u < size_v {
            self.0[u] = DsuBySizeElement::Parent(v);
            self.0[v] = DsuBySizeElement::Size(size_u + size_v);
        } else {
            self.0[v] = DsuBySizeElement::Parent(u);
            self.0[u] = DsuBySizeElement::Size(size_u + size_v);
        }
        true
    }

    /// Returns the root of the set containing `u`.
    pub fn root(&mut self, u: usize) -> usize {
        self.root_size(u).0
    }

    /// Returns the number of elements in the set containing `u`.
    pub fn size(&mut self, u: usize) -> usize {
        self.root_size(u).1
    }

    /// Returns whether `u` and `v` belong to the same set.
    pub fn equiv(&mut self, u: usize, v: usize) -> bool {
        self.root(u) == self.root(v)
    }
}

/// A comparator `(a, b)` together with its index `cei` in the input list.
#[derive(Clone, Copy)]
struct CeEntry {
    cei: usize,
    a: usize,
    b: usize,
}

impl std::fmt::Debug for CeEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "({},{},{})", self.cei, self.a, self.b)
    }
}

/// One step of the verification schedule produced by [`verify_strategy`].
#[derive(Debug, Clone)]
enum VerifyJob {
    /// Apply the comparators `cmp_part`, in order, to the component rooted at `root`.
    Cmp {
        root: usize,
        cmp_part: Vec<CeEntry>,
    },
    /// Merge the component rooted at `root_slave` into the one rooted at `root_master`.
    Combine {
        root_master: usize,
        root_slave: usize,
    },
}

/// Builds the order in which comparators are applied and wire components are merged.
///
/// Each pass scans the comparators that are not yet scheduled, in input order,
/// while tracking which wires have not been touched by an earlier comparator of
/// the same pass. A comparator whose two wires are both untouched is
/// - scheduled, if its wires already belong to the same component, or
/// - recorded as a merge candidate otherwise.
///
/// If a pass schedules nothing, the candidate with the smallest combined
/// component size is turned into a [`VerifyJob::Combine`]. Otherwise one
/// [`VerifyJob::Cmp`] per affected component is emitted.
///
/// # Panics
/// Panics unless `2 <= n <= State::BITS`. Comparators must be 0-indexed with
/// `a < b < n` (checked in debug builds only).
fn verify_strategy(n: usize, cmp: &[(usize, usize)]) -> Vec<VerifyJob> {
    assert!(2 <= n && n <= State::BITS as usize);
    debug_assert!(cmp.iter().all(|&(a, b)| a < b && b < n));
    // cmp_layered[i]: comparator i has already been emitted in some Cmp job.
    let mut cmp_layered = vec![false; cmp.len()];
    // Every comparator before this index has been emitted.
    let mut cmp_skip = 0usize;
    let mut dsu = DsuBySize::new(n);
    let mut layers = vec![];
    while cmp_skip < cmp.len() {
        // Bit i set: wire i has not been touched by a pending comparator in this pass.
        let mut node_avail = State::MAX >> (State::BITS - n as u32);
        // Comparators scheduled in this pass, bucketed by component root.
        let mut layer = vec![Vec::<CeEntry>::new(); n];
        // Best merge candidate as (combined size, root_a, root_b); usize::MAX means none.
        let mut combine = (usize::MAX, 0, 0);
        for (i, &(a, b)) in cmp.iter().enumerate().skip(cmp_skip) {
            debug_assert_eq!(
                node_avail.count_ones() < 2,
                node_avail == 0 || node_avail.is_power_of_two()
            );
            // Fewer than two untouched wires: no further comparator can qualify.
            if node_avail == 0 || node_avail.is_power_of_two() {
                break;
            }
            if cmp_layered[i] {
                continue;
            }
            let node_unavail = ((node_avail >> a) & (node_avail >> b) & 1) == 0;
            // A pending comparator blocks both of its wires for the rest of the pass.
            node_avail &= !((1 as State) << a) & !((1 as State) << b);
            if node_unavail {
                continue;
            }
            if dsu.equiv(a, b) {
                let root_a = dsu.root(a);
                layer[root_a].push(CeEntry { cei: i, a, b });
                cmp_layered[i] = true;
            } else {
                let (root_a, size_a) = dsu.root_size(a);
                let (root_b, size_b) = dsu.root_size(b);
                combine = combine.min((size_a + size_b, root_a, root_b));
            }
        }
        if layer.iter().all(|v| v.is_empty()) {
            // Nothing could be scheduled: merge the cheapest pair of components.
            let (size, root_a, root_b) = combine;
            if size == usize::MAX {
                break;
            }
            let unite_result = dsu.unite(root_a, root_b);
            assert!(unite_result);
            let root_master = dsu.root(root_a);
            // root_master is one of root_a / root_b; the XOR yields the other one.
            let root_slave = root_a ^ root_b ^ root_master;
            layers.push(VerifyJob::Combine {
                root_master,
                root_slave,
            });
        } else {
            for (root, ces) in layer
                .into_iter()
                .enumerate()
                .filter(|(_, v)| !v.is_empty())
            {
                layers.push(VerifyJob::Cmp {
                    root,
                    cmp_part: ces,
                });
            }
            // Advance past the prefix of comparators that are now all scheduled.
            cmp_skip += cmp_layered
                .iter()
                .skip(cmp_skip)
                .take_while(|&&f| f)
                .count();
        }
    }
    layers
}

/// Checks whether the comparator network `cmp` on `n` wires is a sorting network.
///
/// Each connected component of wires carries a list of states `(z, o)`: bit `i`
/// of `z` is set when wire `i` may hold 0, bit `i` of `o` when it may hold 1.
/// Comparators are applied in the order given by [`verify_strategy`], and a
/// state is split in two when both wires of a comparator are still undetermined.
///
/// Worst-case time complexity: `O(FIB1[n] * (m + n*d))`, where `m` is the
/// number of comparators, `n` the number of wires and `d` the depth.
///
/// # Returns
/// - `Ok(unused)`: the network sorts; `unused[j]` is `true` when comparator `j`
///   never changed any state.
/// - `Err(unsorted)`: the network does not sort; `unsorted[k]` (length `n - 1`)
///   is `true` when wires `k` and `k + 1` may end up out of order.
///
/// # Panics
/// Panics unless `2 <= n <= MAX_N`. Comparators must be 0-indexed with
/// `a < b < n` (checked in debug builds only).
///
/// # Safety
/// The CPU must support every feature listed in the `target_feature` attribute.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,bmi1,bmi2,cmpxchg16b,f16c,fma,fxsr,lzcnt,movbe,popcnt,xsave")]
pub unsafe fn is_sorting_network_avx2(
    n: usize,
    cmp: &[(usize, usize)],
) -> Result<Vec<bool>, Vec<bool>> {
    assert!(2 <= n && n <= MAX_N && n <= State::BITS as usize);
    debug_assert!(cmp.iter().all(|&(a, b)| a < b && b < n));
    // Every wire starts as its own component with one state: the wire may be 0 or 1.
    let mut states = (0..n)
        .map(|i| vec![((1 as State) << i, (1 as State) << i)])
        .collect::<Vec<_>>();
    // unused[j]: comparator j has not changed any state so far.
    let mut unused = vec![true; cmp.len()];
    // Bit k set: wires k and k+1 may be left unsorted.
    let mut unsorted_i: State = 0;
    let mut dsu = DsuBySize::new(n);
    for job in verify_strategy(n, cmp) {
        match job {
            VerifyJob::Combine {
                root_master,
                root_slave,
            } => {
                assert_eq!(dsu.root(root_master), root_master);
                assert_eq!(dsu.root(root_slave), root_slave);
                let (conn_nodes_master, conn_nodes_slave) =
                    (dsu.size(root_master), dsu.size(root_slave));
                let unite_result = dsu.unite(root_master, root_slave);
                assert!(unite_result);
                assert_eq!(dsu.root(root_master), root_master);
                let conn_nodes_united = dsu.size(root_master);
                let master_len = states[root_master].len();
                let slave_len = states[root_slave].len();
                // The merged component holds the Cartesian product of both state lists.
                let mut united_status = Vec::with_capacity(master_len * slave_len);
                for &(sz, so) in states[root_slave].iter() {
                    for &(mz, mo) in states[root_master].iter() {
                        united_status.push((sz | mz, so | mo));
                    }
                }
                let united_len = united_status.len();
                states[root_slave] = vec![];
                states[root_master] = united_status;
                if PROGRESS_THRESHOLD <= n {
                    eprintln!("Combining, conn: {conn_nodes_master}+{conn_nodes_slave}=>{conn_nodes_united}, root: ({root_master},{root_slave}), len: {master_len}*{slave_len}=>{united_len}");
                }
            }
            VerifyJob::Cmp { root, cmp_part } => {
                assert_eq!(dsu.root(root), root);
                assert!(cmp_part
                    .iter()
                    .all(|&CeEntry { cei: _, a, b }| dsu.equiv(root, a) && dsu.equiv(root, b)));
                let conn_nodes = dsu.size(root);
                let pre_len = states[root].len();
                // pre_len * 2 bounds the number of next states in most cases
                // (it is rarely exceeded in the worst case).
                let mut states_next =
                    Vec::with_capacity((FIB1[conn_nodes] as usize).min(pre_len * 2));
                // Pending branches: (index of the next comparator to apply, z, o).
                let mut stack = Vec::<(usize, State, State)>::with_capacity(pre_len + n);
                for &(z0, o0) in states[root].iter() {
                    let (mut z, mut o) = (z0, o0);
                    for (i, &CeEntry { cei, a, b }) in cmp_part.iter().enumerate() {
                        if ((o >> a) & 1) == 0 || ((z >> b) & 1) == 0 {
                            // Wire a is surely 0 or wire b is surely 1: nothing changes.
                            continue;
                        } else if ((z >> a) & 1) == 0 || ((o >> b) & 1) == 0 {
                            // Wire a is surely 1 or wire b is surely 0: swap the two wires.
                            unused[cei] = false;
                            let (xz, xo) =
                                (((z >> a) ^ (z >> b)) & 1, ((o >> a) ^ (o >> b)) & 1);
                            z ^= (xz << a) | (xz << b);
                            o ^= (xo << a) | (xo << b);
                        } else {
                            // Both wires undetermined: split into "both 0" (queued)
                            // and "b is 1, a undetermined" (continued here).
                            unused[cei] = false;
                            let (qz, qo) = (z, o ^ ((1 as State) << a) ^ ((1 as State) << b));
                            z ^= (1 as State) << b;
                            stack.push((i + 1, qz, qo));
                        }
                    }
                    states_next.push((z, o));
                }
                // Finish the queued branches from the comparator they stopped at.
                while let Some((mut i, mut z, mut o)) = stack.pop() {
                    while let Some(&CeEntry { cei, a, b }) = cmp_part.get(i) {
                        i += 1;
                        if ((o >> a) & 1) == 0 || ((z >> b) & 1) == 0 {
                            continue;
                        } else if ((z >> a) & 1) == 0 || ((o >> b) & 1) == 0 {
                            unused[cei] = false;
                            let (xz, xo) =
                                (((z >> a) ^ (z >> b)) & 1, ((o >> a) ^ (o >> b)) & 1);
                            z ^= (xz << a) | (xz << b);
                            o ^= (xo << a) | (xo << b);
                        } else {
                            unused[cei] = false;
                            let (qz, qo) = (z, o ^ ((1 as State) << a) ^ ((1 as State) << b));
                            z ^= (1 as State) << b;
                            stack.push((i, qz, qo));
                        }
                    }
                    states_next.push((z, o));
                }
                let gen_len = states_next.len();
                // Deduplicate the generated states.
                states_next.sort_unstable();
                states_next.dedup();
                let dedup_len = states_next.len();
                // Write back the next states.
                states[root] = states_next;
                if PROGRESS_THRESHOLD <= n {
                    eprintln!(
                        "AppliedCE, conn: {conn_nodes}, root: {root}, len: {pre_len}=>{gen_len}=>{dedup_len}, cmp: {cmp_part:?}"
                    );
                }
            }
        }
    }
    // Only pairs (k, k+1) with k < n-1 exist.
    let n1_mask = State::MAX >> (State::BITS - (n - 1) as u32);
    for states_parroot in states.iter() {
        // Wires covered by this component, read from its first state.
        // NOTE(review): assumes every state of a component covers the same wires.
        let q_mask = states_parroot.first().map(|&(z, o)| z | o).unwrap_or(0);
        // Wire k is in the component but wire k+1 is not: the pair was never related.
        unsorted_i |= (q_mask & (!q_mask >> 1)) & n1_mask;
        for &(z, o) in states_parroot.iter() {
            // Wire k may be 1 while wire k+1 may be 0.
            unsorted_i |= o & (z >> 1);
        }
    }
    // All branches have ended.
    if PROGRESS_THRESHOLD <= n {
        eprintln!();
    }
    if unsorted_i != 0 {
        // Some branch is not sorted: report the positions that may be unsorted.
        Err((0..n - 1).map(|k| (unsorted_i >> k) & 1 != 0).collect())
    } else {
        // Every branch is sorted: report the unused comparators.
        Ok(unused)
    }
}

/// Checks whether `cmp` is a sorting network by simulating all `2^n` 0/1 inputs.
///
/// Inputs are processed in batches of 1024 (16 words of 64 bits per wire):
/// wires 0..=5 vary inside a word, wires 6..=9 select the word, and wires from
/// 10 upwards are fixed per batch by the outer loop counter.
///
/// # Returns
/// - `Ok(unused)`: the network sorts; `unused[j]` is `true` when comparator `j`
///   never swapped on any input.
/// - `Err(unsorted)`: the network does not sort; `unsorted[k]` (length `n - 1`)
///   is `true` when some input leaves wire `k` at 1 and wire `k + 1` at 0.
///
/// # Panics
/// Panics unless `2 <= n <= 32`. Comparators must be 0-indexed with
/// `a < b < n` (checked in debug builds only).
///
/// # Safety
/// The CPU must support every feature listed in the `target_feature` attribute.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,bmi1,bmi2,cmpxchg16b,f16c,fma,fxsr,lzcnt,movbe,popcnt,xsave")]
pub unsafe fn is_sorting_network_pow2_avx2(
    n: usize,
    cmp: &[(usize, usize)],
) -> Result<Vec<bool>, Vec<bool>> {
    assert!(2 <= n && n <= u32::BITS as usize && n <= MAX_N);
    debug_assert!(cmp.iter().all(|&(a, b)| a < b && b < n));

    /// Word pattern for wire `6 + n`: word `i` is all ones when bit `n` of `i` is set.
    const fn s_upper(n: u32) -> [u64; 16] {
        let mut s = [0; 16];
        let mut i = 0usize;
        while i < 16 {
            s[i] = ((i as u64 >> n) & 1).wrapping_neg();
            i += 1;
        }
        s
    }
    // Bit pattern for wire k (k < 6): bit j of the word equals bit k of j.
    const S_LOW: [u64; 6] = [
        0xaaaaaaaaaaaaaaaa,
        0xcccccccccccccccc,
        0xf0f0f0f0f0f0f0f0,
        0xff00ff00ff00ff00,
        0xffff0000ffff0000,
        0xffffffff00000000,
    ];
    const S_UPPER: [[u64; 16]; 4] = [s_upper(0), s_upper(1), s_upper(2), s_upper(3)];

    let mut states: [[u64; 16]; MAX_N] = [[0; 16]; MAX_N];
    // unused[j]: comparator j has not swapped on any input so far.
    let mut unused = vec![true; cmp.len()];
    // unsorted[k]: wires k and k+1 were out of order for some input.
    let mut unsorted = vec![false; n - 1];
    for i in 0..(1u64 << n.saturating_sub(10)) {
        // Load the 1024 input assignments of this batch.
        for (se, &le) in states.iter_mut().zip(S_LOW.iter()) {
            se.fill(le);
        }
        states[6] = S_UPPER[0];
        states[7] = S_UPPER[1];
        states[8] = S_UPPER[2];
        states[9] = S_UPPER[3];
        for j in 10..n {
            let v = ((i >> (j - 10)) & 1).wrapping_neg();
            states[j].fill(v);
        }
        for (&(a, b), unused_j) in cmp.iter().zip(unused.iter_mut()) {
            let (va, vb) = (&states[a], &states[b]);
            // min = AND goes to wire a, max = OR goes to wire b.
            let (na, nb): ([u64; 16], [u64; 16]) = {
                // The intrinsics live in a per-architecture module.
                #[cfg(target_arch = "x86")]
                use std::arch::x86::{__m256i, _mm256_and_si256, _mm256_or_si256};
                #[cfg(target_arch = "x86_64")]
                use std::arch::x86_64::{__m256i, _mm256_and_si256, _mm256_or_si256};
                use std::mem::transmute;
                // SAFETY: `[u64; 16]` and `[__m256i; 4]` are both 128 bytes of plain
                // integer data, and AVX2 is enabled for this function.
                unsafe {
                    let va: [__m256i; 4] = transmute(*va);
                    let vb: [__m256i; 4] = transmute(*vb);
                    (
                        transmute([
                            _mm256_and_si256(va[0], vb[0]),
                            _mm256_and_si256(va[1], vb[1]),
                            _mm256_and_si256(va[2], vb[2]),
                            _mm256_and_si256(va[3], vb[3]),
                        ]),
                        transmute([
                            _mm256_or_si256(va[0], vb[0]),
                            _mm256_or_si256(va[1], vb[1]),
                            _mm256_or_si256(va[2], vb[2]),
                            _mm256_or_si256(va[3], vb[3]),
                        ]),
                    )
                }
            };
            debug_assert_eq!(std::array::from_fn::<u64, 16, _>(|i| va[i] & vb[i]), na);
            debug_assert_eq!(std::array::from_fn::<u64, 16, _>(|i| va[i] | vb[i]), nb);
            // Wire a changed, so the comparator swapped for at least one input.
            if *va != na {
                *unused_j = false;
            }
            states[a] = na;
            states[b] = nb;
        }
        for (se, unsorted_k) in states[..n].windows(2).zip(unsorted.iter_mut()) {
            if se[0].iter().zip(se[1].iter()).any(|(&a, &b)| (a & !b) != 0) {
                *unsorted_k = true;
            }
        }
    }
    if unsorted.iter().any(|&f| f) {
        // Some input is not sorted: report the positions that may be unsorted.
        Err(unsorted)
    } else {
        // Every input is sorted: report the unused comparators.
        Ok(unused)
    }
}

/// Returns the next input line, or an error on end of input or I/O failure.
fn next_line<I>(lines: &mut I) -> Result<String, Box<dyn std::error::Error>>
where
    I: Iterator<Item = std::io::Result<String>>,
{
    Ok(lines.next().ok_or("unexpected end of input")??)
}

/// Parses a whitespace-separated list of non-negative integers.
fn parse_indices(line: &str) -> Result<Vec<usize>, std::num::ParseIntError> {
    line.split_whitespace().map(str::parse::<usize>).collect()
}

/// Reads `t` test cases from stdin and prints, for each, `Yes` plus the unused
/// comparators or `No` plus the positions that may be unsorted (all 1-indexed).
///
/// Malformed or truncated input is returned as an error; violated input
/// constraints trip the assertions.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    use std::io::Write;
    let execution_start = std::time::Instant::now();
    let stdin = std::io::stdin();
    let mut lines = std::io::BufRead::lines(stdin.lock());
    let mut bout = std::io::BufWriter::new(std::io::stdout());
    let t: usize = next_line(&mut lines)?.trim().parse()?;
    assert!(t <= MAX_T);
    // phi = (1 + sqrt(5)) / 2: golden ratio 1.618033988749895
    let phi = (1.25f64).sqrt() + 0.5;
    let mut cost = 0f64;
    for _ in 0..t {
        let line = next_line(&mut lines)?;
        let mut parts = line.split_whitespace();
        let n: usize = parts.next().ok_or("missing n")?.parse()?;
        let m: usize = parts.next().ok_or("missing m")?.parse()?;
        assert!(2 <= n && n <= MAX_N);
        assert!(1 <= m && m <= n * (n - 1) / 2);
        cost += m as f64 * phi.powi(n as i32);
        // Accumulated test case cost <= MAX_COST
        assert!(cost <= MAX_COST);
        // Read comparators (1-indexed in the input).
        let vec_a = parse_indices(&next_line(&mut lines)?)?;
        let vec_b = parse_indices(&next_line(&mut lines)?)?;
        assert!(vec_a.len() == m && vec_b.len() == m);
        assert!(vec_a.iter().all(|&a| 1 <= a && a <= n));
        assert!(vec_b.iter().all(|&b| 1 <= b && b <= n));
        let cmp = vec_a
            .iter()
            .zip(vec_b.iter())
            .map(|(&a, &b)| (a - 1, b - 1))
            .collect::<Vec<_>>();
        assert!(cmp.len() == m);
        assert!(cmp.iter().all(|&(a, b)| a < b));
        // Check if it is a sorting network.
        assert!(std::arch::is_x86_feature_detected!("avx2"));
        // SAFETY: AVX2 was just detected.
        // NOTE(review): the other enabled features (BMI1/2, FMA, LZCNT, ...) are
        // not checked here and are assumed to be present on the target machine.
        let result = if n <= 17 {
            unsafe { is_sorting_network_pow2_avx2(n, &cmp) }
        } else {
            unsafe { is_sorting_network_avx2(n, &cmp) }
        };
        // "Yes" lists unused comparators j, "No" lists positions k that may be unsorted.
        let (verdict, flags) = match result {
            Ok(unused) => ("Yes", unused),
            Err(unsorted) => ("No", unsorted),
        };
        // 1-indexed
        let indices = flags
            .iter()
            .enumerate()
            .filter(|&(_, &f)| f)
            .map(|(j, _)| (j + 1).to_string())
            .collect::<Vec<_>>();
        writeln!(bout, "{}", verdict)?;
        writeln!(bout, "{}", indices.len())?;
        writeln!(bout, "{}", indices.join(" "))?;
    }
    bout.flush()?;
    eprintln!("{:.6}[s]", execution_start.elapsed().as_secs_f64());
    Ok(())
}