結果

問題 No.3047 Verification of Sorting Network
ユーザー 👑 Mizar
提出日時 2025-03-14 19:36:09
言語 Rust
(1.83.0 + proconio)
結果
AC  
実行時間 83 ms / 2,000 ms
コード長 18,887 bytes
コンパイル時間 15,074 ms
コンパイル使用メモリ 397,712 KB
実行使用メモリ 7,324 KB
最終ジャッジ日時 2025-03-14 19:36:31
合計ジャッジ時間 19,177 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 61
権限があれば一括ダウンロードができます

ソースコード

diff #

// Progress diagnostics are written to stderr only when n is at least this value.
const PROGRESS_THRESHOLD: usize = 28;

// Upper bound on the number of test cases t (asserted in main).
const MAX_T: usize = 1000;
// Upper bound on the number of wires n (asserted in main and in the verifiers).
const MAX_N: usize = 64;
// Upper bound on the accumulated cost, i.e. the running sum of m * phi^n over
// all test cases (asserted in main).
const MAX_COST: f64 = 1e17;

// Bit set with one bit per wire of the comparator network.
type State = u64;

// Shifted Fibonacci table: FIB1[0] = FIB1[1] = 1 and every later entry is the
// sum of the two entries before it, for indices up to and including State::BITS.
const FIB1: [State; (State::BITS + 1) as usize] = {
    const ENTRIES: usize = (State::BITS + 1) as usize;
    let mut table: [State; ENTRIES] = [1; ENTRIES];
    // `for` is not available in const evaluation, hence the manual counter.
    let mut k = 2;
    while k < ENTRIES {
        table[k] = table[k - 2] + table[k - 1];
        k += 1;
    }
    table
};

// One slot of the union-find forest: a root stores the size of its set,
// any other node stores the index of its parent.
#[derive(Debug, Clone, Copy)]
enum DsuBySizeElement {
    Size(usize),
    Parent(usize),
}
// Disjoint Set Union (Union-Find) with union by size and path compression.
#[derive(Debug, Clone)]
pub struct DsuBySize(Vec<DsuBySizeElement>);
impl DsuBySize {
    // Creates `n` singleton sets.
    pub fn new(n: usize) -> Self {
        Self(vec![DsuBySizeElement::Size(1); n])
    }
    // Returns `(root, size)` of the set containing `u` and re-points every
    // node on the walked path directly at the root.
    pub fn root_size(&mut self, u: usize) -> (usize, usize) {
        // Pass 1: climb to the representative.
        let mut node = u;
        let (root, size) = loop {
            match self.0[node] {
                DsuBySizeElement::Size(size) => break (node, size),
                // A node that is its own parent is treated as a singleton root.
                DsuBySizeElement::Parent(up) if up == node => break (node, 1),
                DsuBySizeElement::Parent(up) => node = up,
            }
        };
        // Pass 2: compress the path that was just walked.
        let mut node = u;
        while let DsuBySizeElement::Parent(up) = self.0[node] {
            if up == node {
                break;
            }
            self.0[node] = DsuBySizeElement::Parent(root);
            node = up;
        }
        (root, size)
    }
    // Merges the sets of `u` and `v`; returns false when they already coincide.
    // The smaller set is attached below the larger one (ties keep `u`'s root).
    pub fn unite(&mut self, u: usize, v: usize) -> bool {
        let (root_u, size_u) = self.root_size(u);
        let (root_v, size_v) = self.root_size(v);
        if root_u == root_v {
            return false;
        }
        let (winner, loser) = if size_u < size_v {
            (root_v, root_u)
        } else {
            (root_u, root_v)
        };
        self.0[loser] = DsuBySizeElement::Parent(winner);
        self.0[winner] = DsuBySizeElement::Size(size_u + size_v);
        true
    }
    // Representative of the set containing `u`.
    pub fn root(&mut self, u: usize) -> usize {
        let (root, _) = self.root_size(u);
        root
    }
    // Number of elements in the set containing `u`.
    pub fn size(&mut self, u: usize) -> usize {
        let (_, size) = self.root_size(u);
        size
    }
    // Whether `u` and `v` belong to the same set.
    pub fn equiv(&mut self, u: usize, v: usize) -> bool {
        let root_u = self.root(u);
        let root_v = self.root(v);
        root_u == root_v
    }
}

// A comparator on wires (a, b) together with `cei`, its index in the
// comparator list it was taken from.
#[derive(Clone, Copy)]
struct CeEntry {
    cei: usize,
    a: usize,
    b: usize,
}
// Compact debug form "(cei,a,b)" so that lists of comparators stay on one line.
impl std::fmt::Debug for CeEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let CeEntry { cei, a, b } = *self;
        write!(f, "({cei},{a},{b})")
    }
}
// One step of the verification plan produced by verify_strategy and consumed,
// in order, by is_sorting_network_avx2.
#[derive(Debug, Clone)]
enum VerifyJob {
    // Apply the comparators in `cmp_part` to the state set stored at DSU root `root`.
    // verify_strategy only places a comparator here when both of its wires already
    // belong to the set rooted at `root`.
    Cmp {
        root: usize,
        cmp_part: Vec<CeEntry>,
    },
    // Merge two disjoint sets of wires. `root_master` is the root of the united set,
    // `root_slave` is the other former root.
    Combine {
        root_master: usize,
        root_slave: usize,
    },
}

// Builds the job list that drives the verifier: comparators are grouped into
// rounds in which every wire is touched at most once, and two wire sets are
// merged only when a round could not schedule any comparator.
//
// n:   number of wires, 2 <= n <= State::BITS.
// cmp: 0-indexed comparators (a, b) with a < b < n.
//
// Returns `VerifyJob::Cmp` / `VerifyJob::Combine` entries in execution order.
fn verify_strategy(n: usize, cmp: &[(usize, usize)]) -> Vec<VerifyJob> {
    assert!(2 <= n && n <= State::BITS as _);
    debug_assert!(cmp.iter().all(|&(a, b)| a < b && b < n));
    // Mask with the low n bits set: every wire is free at the start of a round.
    let all_wires: State = State::MAX >> (State::BITS - n as u32);
    let mut scheduled = vec![false; cmp.len()];
    let mut first_pending = 0usize;
    let mut dsu = DsuBySize::new(n);
    let mut jobs = Vec::new();
    while first_pending < cmp.len() {
        let mut free_wires = all_wires;
        // Comparators picked in this round, bucketed by the DSU root of their wires.
        let mut per_root: Vec<Vec<CeEntry>> = vec![Vec::new(); n];
        // Smallest (combined size, root, root) among comparators that bridge two sets.
        let mut cheapest_merge = (usize::MAX, 0, 0);
        for (i, &(a, b)) in cmp.iter().enumerate().skip(first_pending) {
            // With fewer than two free wires no further comparator can be placed.
            if free_wires.count_ones() < 2 {
                break;
            }
            if scheduled[i] {
                continue;
            }
            let pair = ((1 as State) << a) | ((1 as State) << b);
            let both_free = (free_wires & pair) == pair;
            // Either way the two wires are blocked for the rest of this round, so
            // later comparators cannot overtake this one.
            free_wires &= !pair;
            if !both_free {
                continue;
            }
            let (root_a, size_a) = dsu.root_size(a);
            let (root_b, size_b) = dsu.root_size(b);
            if root_a == root_b {
                per_root[root_a].push(CeEntry { cei: i, a, b });
                scheduled[i] = true;
            } else {
                cheapest_merge = cheapest_merge.min((size_a + size_b, root_a, root_b));
            }
        }
        if per_root.iter().all(|bucket| bucket.is_empty()) {
            // Nothing could be applied: merge the cheapest pair of sets instead.
            let (size, root_a, root_b) = cheapest_merge;
            if size == usize::MAX {
                break;
            }
            let unite_result = dsu.unite(root_a, root_b);
            assert!(unite_result);
            let root_master = dsu.root(root_a);
            // The united root is one of the two former roots; the slave is the other.
            let root_slave = if root_master == root_a { root_b } else { root_a };
            jobs.push(VerifyJob::Combine {
                root_master,
                root_slave,
            });
        } else {
            // Emit one comparator job per non-empty bucket, in root order.
            jobs.extend(
                per_root
                    .into_iter()
                    .enumerate()
                    .filter(|(_, bucket)| !bucket.is_empty())
                    .map(|(root, cmp_part)| VerifyJob::Cmp { root, cmp_part }),
            );
            // Skip the prefix of comparators that is now completely scheduled.
            first_pending += scheduled[first_pending..]
                .iter()
                .take_while(|&&done| done)
                .count();
        }
    }
    jobs
}

// Check if the given comparator network is a sorting network.
//
// The jobs produced by verify_strategy are executed in order while a list of
// (z, o) bit-mask pairs is maintained per DSU root. The checks below are
// consistent with reading bit w of `z` as "wire w may hold 0" and bit w of `o`
// as "wire w may hold 1" (interpretation inferred from the comparator cases and
// the final `o & (z >> 1)` test; the code itself does not name them).
//
// n:   number of wires, 2 <= n <= MAX_N.
// cmp: 0-indexed comparators (a, b) with a < b < n.
//
// Returns Ok(unused), where unused[j] is true if comparator j never changed any
// tracked state, when no adjacent pair can end up unsorted; otherwise returns
// Err(v) with n - 1 flags, v[k] being true when wires k and k + 1 may be unsorted.
//
// Safety: this function is compiled with the target features listed in the
// attribute below; the caller must make sure the running CPU supports them.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,bmi1,bmi2,cmpxchg16b,f16c,fma,fxsr,lzcnt,movbe,popcnt,xsave")]
pub unsafe fn is_sorting_network_avx2(n: usize, cmp: &[(usize, usize)]) -> Result<Vec<bool>, Vec<bool>> {
    // Worst-case time complexity: O(FIB1[n] * (m + n*d))
    // m: number of comparators, n: number of input elements, d: depth (number of layers)
    assert!(2 <= n && (n as usize) <= MAX_N && n <= State::BITS as _);
    // Ensure 0-indexed and a < b and b < n
    debug_assert!(cmp.iter().all(|&(a, b)| a < b && b < n));
    // states[r]: state list of the wire set rooted at r; every wire starts as its
    // own set with the single state (1 << i, 1 << i).
    let mut states = (0..n)
        .map(|i| vec![((1 as State) << i, (1 as State) << i)])
        .collect::<Vec<_>>();
    // unused[j]: true until the j-th comparator changes (swaps or branches) some state
    let mut unused = vec![true; cmp.len()];
    // unsorted_i: bit k is set when the (k, k+1) wire pair may be unsorted
    let mut unsorted_i: State = 0;
    // Mirrors the DSU evolution inside verify_strategy so roots and sizes can be checked.
    let mut dsu = DsuBySize::new(n);
    for job in verify_strategy(n, cmp) {
        match job {
            VerifyJob::Combine {
                root_master,
                root_slave,
            } => {
                assert_eq!(dsu.root(root_master), root_master);
                assert_eq!(dsu.root(root_slave), root_slave);
                let (conn_nodes_master, conn_nodes_slave) =
                    (dsu.size(root_master), dsu.size(root_slave));
                let unite_result = dsu.unite(root_master, root_slave);
                assert!(unite_result);
                assert_eq!(dsu.root(root_master), root_master);
                let conn_nodes_united = dsu.size(root_master);
                let master_len = states[root_master].len();
                let slave_len = states[root_slave].len();
                // The united state list is the cross product of both lists, with the
                // masks of each pair OR-ed together.
                let mut united_status =
                    Vec::with_capacity(states[root_master].len() * states[root_slave].len());
                for &(sz, so) in states[root_slave].iter() {
                    for &(mz, mo) in states[root_master].iter() {
                        united_status.push((sz | mz, so | mo));
                    }
                }
                let united_len = united_status.len();
                // The slave root no longer owns any states.
                states[root_slave] = vec![];
                states[root_master] = united_status;
                if PROGRESS_THRESHOLD <= n {
                    eprintln!("Combining, conn: {conn_nodes_master}+{conn_nodes_slave}=>{conn_nodes_united}, root: ({root_master},{root_slave}), len: {master_len}*{slave_len}=>{united_len}");
                }
            }
            VerifyJob::Cmp { root, cmp_part } => {
                assert_eq!(dsu.root(root), root);
                assert!(cmp_part
                    .iter()
                    .all(|&CeEntry { cei: _, a, b }| dsu.equiv(root, a) && dsu.equiv(root, b)));
                let conn_nodes = dsu.size(root);
                let pre_len = states[root].len();
                // pre_len * 2 is an upper bound on the number of next states in most cases
                // (it is only rarely exceeded, in the worst case)
                let mut states_next =
                    Vec::with_capacity((FIB1[conn_nodes] as usize).min(pre_len * 2));
                // Pending branches: (index of the next comparator to apply, z, o).
                let mut stack = Vec::<(usize, State, State)>::with_capacity(states[root].len() + n);
                for (mut z, mut o) in states[root].iter() {
                    for (i, &CeEntry { cei, a, b }) in cmp_part.iter().enumerate() {
                        if (o >> a) & 1 == 0 || (z >> b) & 1 == 0 {
                            // o lacks bit a or z lacks bit b: the comparator changes nothing.
                            continue;
                        } else if (z >> a) & 1 == 0 || (o >> b) & 1 == 0 {
                            // z lacks bit a or o lacks bit b: exchange bits a and b in both masks.
                            unused[cei] = false;
                            let (xz, xo) = (((z >> a) ^ (z >> b)) & 1, ((o >> a) ^ (o >> b)) & 1);
                            z ^= xz << a | xz << b;
                            o ^= xo << a | xo << b;
                        } else {
                            // Both masks have both bits set: split into two states.
                            // The deferred state clears bits a and b of o; the current
                            // state continues with bit b of z cleared.
                            unused[cei] = false;
                            let (qz, qo) = (z, o ^ ((1 as State) << a) ^ ((1 as State) << b));
                            z ^= (1 as State) << b;
                            stack.push((i + 1, qz, qo));
                        }
                    }
                    states_next.push((z, o));
                }
                // Finish the deferred branches, resuming each at the comparator that
                // follows the one where it was split off.
                while let Some((mut i, mut z, mut o)) = stack.pop() {
                    while let Some(&CeEntry { cei, a, b }) = cmp_part.get(i) {
                        i += 1;
                        if (o >> a) & 1 == 0 || (z >> b) & 1 == 0 {
                            continue;
                        } else if (z >> a) & 1 == 0 || (o >> b) & 1 == 0 {
                            unused[cei] = false;
                            let (xz, xo) = (((z >> a) ^ (z >> b)) & 1, ((o >> a) ^ (o >> b)) & 1);
                            z ^= xz << a | xz << b;
                            o ^= xo << a | xo << b;
                        } else {
                            unused[cei] = false;
                            let (qz, qo) = (z, o ^ ((1 as State) << a) ^ ((1 as State) << b));
                            z ^= (1 as State) << b;
                            stack.push((i, qz, qo));
                        }
                    }
                    states_next.push((z, o));
                }
                let gen_len = states_next.len();
                // deduplicate
                states_next.sort_unstable();
                states_next.dedup();
                let dedup_len = states_next.len();
                // write back next states
                states[root] = states_next;
                if PROGRESS_THRESHOLD <= n {
                    eprintln!(
                        "AppliedCE, conn: {conn_nodes}, root: {root}, len: {pre_len}=>{gen_len}=>{dedup_len}, cmp: {cmp_part:?}"
                    );
                }
            }
        }
    }
    for states_parroot in states.iter() {
        // Mask of the low n - 1 bits: the valid positions k of a (k, k+1) pair.
        let n1_mask = State::MAX >> (State::BITS - (n - 1) as u32);
        // z | o of the first state; presumably the set of wires covered by this root
        // (an emptied slave list yields 0).
        let q_mask = states_parroot.first().map(|&(z, o)| z | o).unwrap_or(0);
        // Flag every k that is in q_mask while k + 1 is not.
        unsorted_i |= (q_mask & (!q_mask >> 1)) & n1_mask;
        for &(z, o) in states_parroot.iter() {
            // Flag every k with bit k set in o and bit k + 1 set in z.
            unsorted_i |= o & (z >> 1);
        }
    }
    // All branches have ended
    if PROGRESS_THRESHOLD <= n {
        eprintln!();
    }
    // If any branch is not sorted, unsorted_i will be non-zero, indicating it is not a sorting network
    if unsorted_i != 0 {
        // Return positions that may be unsorted
        Err(Vec::from_iter(
            (0..n - 1).map(|k| (unsorted_i >> k) & 1 != 0),
        ))
    } else {
        // If all branches are sorted, it is a sorting network
        // Return unused comparators
        Ok(unused)
    }
}

// Check if the given comparator network is a sorting network by simulating
// every 0/1 input assignment bit-parallel (sufficient by the zero-one principle).
//
// Each wire holds 1024 lanes (16 x u64); lane index q * 64 + p together with the
// outer counter i encodes one input assignment:
//   wires 0..=5  take their bit from p (S_LOW patterns),
//   wires 6..=9  take their bit from q (S_UPPER patterns),
//   wires 10..n  take their bit from i.
// A comparator (a, b) then becomes lane-wise AND on wire a and OR on wire b.
//
// n:   number of wires, 2 <= n <= 32 (and n <= MAX_N).
// cmp: 0-indexed comparators (a, b) with a < b < n.
//
// Returns Ok(unused), where unused[j] is true if comparator j never changed
// wire a for any input, when every input ends up sorted; otherwise returns
// Err(unsorted) with n - 1 flags, unsorted[k] being true when some input leaves
// wire k at 1 and wire k + 1 at 0.
//
// Safety: this function is compiled with the target features listed in the
// attribute below; the caller must make sure the running CPU supports them.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,bmi1,bmi2,cmpxchg16b,f16c,fma,fxsr,lzcnt,movbe,popcnt,xsave")]
pub unsafe fn is_sorting_network_pow2_avx2(
    n: usize,
    cmp: &[(usize, usize)],
) -> Result<Vec<bool>, Vec<bool>> {
    // The cfg above admits 32-bit x86 as well, so pick the matching intrinsics module.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::{__m256i, _mm256_and_si256, _mm256_or_si256};
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::{__m256i, _mm256_and_si256, _mm256_or_si256};
    use std::mem::transmute;

    assert!(2 <= n && n <= u32::BITS as _ && n <= MAX_N);
    // Ensure 0-indexed and a < b and b < n
    debug_assert!(cmp.iter().all(|&(a, b)| a < b && b < n));
    // s_upper(k)[q] is all-ones exactly when bit k of the word index q is set.
    const fn s_upper(n: u32) -> [u64; 16] {
        let mut s = [0; 16];
        let mut i = 0usize;
        while i < 16 {
            s[i] = ((i as u64 >> n) & 1).wrapping_neg();
            i += 1;
        }
        s
    }
    // S_LOW[k] has bit p set exactly when bit k of the bit index p is set.
    const S_LOW: [u64; 6] = [
        0xaaaaaaaaaaaaaaaa,
        0xcccccccccccccccc,
        0xf0f0f0f0f0f0f0f0,
        0xff00ff00ff00ff00,
        0xffff0000ffff0000,
        0xffffffff00000000,
    ];
    const S_UPPER: [[u64; 16]; 4] = [s_upper(0), s_upper(1), s_upper(2), s_upper(3)];
    let mut states: [[u64; 16]; MAX_N] = [[0; 16]; MAX_N];
    // unused[j]: true until the j-th comparator changes wire a for some input
    let mut unused = vec![true; cmp.len()];
    // unsorted[k]: whether the (k, k+1) wire pair may be unsorted
    let mut unsorted = vec![false; n - 1];

    // One iteration per assignment of wires 10..n (a single iteration when n <= 10).
    for i in 0..(1u64 << n.saturating_sub(10)) {
        // Reload the input patterns: the previous iteration overwrote them.
        for (se, &le) in states.iter_mut().zip(S_LOW.iter()) {
            se.fill(le);
        }
        states[6] = S_UPPER[0];
        states[7] = S_UPPER[1];
        states[8] = S_UPPER[2];
        states[9] = S_UPPER[3];
        for j in 10..n {
            let v = ((i >> (j - 10)) & 1).wrapping_neg();
            states[j].fill(v);
        }
        for (&(a, b), unused_j) in cmp.iter().zip(unused.iter_mut()) {
            let (va, vb) = (states[a], states[b]);
            // SAFETY: [u64; 16] and [__m256i; 4] are both 128 bytes of plain integer
            // data, so the transmutes are size-correct in both directions, and AVX2
            // is enabled for this function by its `target_feature` attribute.
            let (na, nb): ([u64; 16], [u64; 16]) = unsafe {
                let xa: [__m256i; 4] = transmute(va);
                let xb: [__m256i; 4] = transmute(vb);
                (
                    transmute([
                        _mm256_and_si256(xa[0], xb[0]),
                        _mm256_and_si256(xa[1], xb[1]),
                        _mm256_and_si256(xa[2], xb[2]),
                        _mm256_and_si256(xa[3], xb[3]),
                    ]),
                    transmute([
                        _mm256_or_si256(xa[0], xb[0]),
                        _mm256_or_si256(xa[1], xb[1]),
                        _mm256_or_si256(xa[2], xb[2]),
                        _mm256_or_si256(xa[3], xb[3]),
                    ]),
                )
            };
            // Scalar reference for the vector code above (debug builds only).
            debug_assert_eq!(std::array::from_fn(|i| va[i] & vb[i]), na);
            debug_assert_eq!(std::array::from_fn(|i| va[i] | vb[i]), nb);
            // The comparator did something iff the minimum differs from the old wire a.
            if va != na {
                *unused_j = false;
            }
            states[a] = na;
            states[b] = nb;
        }
        for (se, unsorted_k) in states[..n].windows(2).zip(unsorted.iter_mut()) {
            // Some lane has wire k = 1 and wire k + 1 = 0.
            if se[0].iter().zip(se[1].iter()).any(|(&a, &b)| (a & !b) != 0) {
                *unsorted_k = true;
            }
        }
    }

    // If any branch is not sorted, unsorted will contain true, indicating it is not a sorting network
    if unsorted.iter().any(|&f| f) {
        // Return positions that may be unsorted
        Err(unsorted)
    } else {
        // If all branches are sorted, it is a sorting network
        // Return unused comparators
        Ok(unused)
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    use std::io::Write;
    let execution_start = std::time::Instant::now();
    let stdin = std::io::stdin();
    let mut lines = std::io::BufRead::lines(stdin.lock());
    let mut bout = std::io::BufWriter::new(std::io::stdout());

    let t: usize = lines.next().unwrap()?.trim().parse()?;
    assert!(t <= MAX_T);

    // φ = (1 + √5) / 2 : golden ratio 1.618033988749895
    let phi = (1.25f64).sqrt() + 0.5;
    let mut cost = 0f64;

    for _ in 0..t {
        let line = lines.next().unwrap()?;
        let mut parts = line.split_whitespace();
        let n: usize = parts.next().unwrap().parse()?;
        let m: usize = parts.next().unwrap().parse()?;
        assert!(2 <= n && (n as usize) <= MAX_N);
        assert!(1 <= m && m <= (n as usize) * ((n as usize) - 1) / 2);
        cost += m as f64 * phi.powi(n as i32);
        // Test case cost <= MAX_COST
        assert!(cost <= MAX_COST);

        // Read comparators
        let vec_a = lines
            .next()
            .unwrap()?
            .split_whitespace()
            .map(|s| s.parse::<usize>().unwrap())
            .collect::<Vec<_>>();
        let vec_b = lines
            .next()
            .unwrap()?
            .split_whitespace()
            .map(|s| s.parse::<usize>().unwrap())
            .collect::<Vec<_>>();
        assert!(vec_a.len() == m && vec_b.len() == m);
        assert!(vec_a.iter().all(|&a| 1 <= a && a <= n));
        assert!(vec_b.iter().all(|&b| 1 <= b && b <= n));
        let cmp = vec_a
            .iter()
            .zip(vec_b.iter())
            .map(|(&a, &b)| ((a - 1) as usize, (b - 1) as usize))
            .collect::<Vec<_>>();
        assert!(cmp.len() == m);
        assert!(cmp.iter().all(|&(a, b)| a < b));

        // Check if it is a sorting network
        assert!(std::arch::is_x86_feature_detected!("avx2"));
        let result = if n <= 17 {
            unsafe { is_sorting_network_pow2_avx2(n, &cmp) }
        } else {
            unsafe { is_sorting_network_avx2(n, &cmp) }
        };
        match result {
            Ok(unused) => {
                writeln!(&mut bout, "Yes")?;
                // List unused comparators j
                writeln!(&mut bout, "{}", unused.iter().filter(|&&f| f).count())?;
                // 1-indexed
                writeln!(
                    &mut bout,
                    "{}",
                    unused
                        .iter()
                        .enumerate()
                        .filter_map(|(j, &u)| if u { Some((j + 1).to_string()) } else { None })
                        .collect::<Vec<_>>()
                        .join(" ")
                )?;
            }
            Err(unsorted) => {
                writeln!(&mut bout, "No")?;
                // List positions k that may be unsorted
                writeln!(&mut bout, "{}", unsorted.iter().filter(|&&f| f).count())?;
                // 1-indexed
                writeln!(
                    &mut bout,
                    "{}",
                    unsorted
                        .iter()
                        .enumerate()
                        .filter_map(|(k, &u)| if u { Some((k + 1).to_string()) } else { None })
                        .collect::<Vec<_>>()
                        .join(" ")
                )?;
            }
        }
    }
    bout.flush()?;
    eprintln!("{:.6}[s]", execution_start.elapsed().as_secs_f64());
    Ok(())
}
0