結果

問題 No.3047 Verification of Sorting Network
ユーザー 👑 Mizar
提出日時 2025-03-13 11:13:41
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 90 ms / 2,000 ms
コード長 15,395 bytes
コンパイル時間 6,469 ms
コンパイル使用メモリ 341,680 KB
実行使用メモリ 7,328 KB
最終ジャッジ日時 2025-03-13 11:22:27
合計ジャッジ時間 11,413 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 61
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h> // all
#pragma GCC optimize ("O3")
#pragma GCC target ("arch=x86-64-v3,tune=native")
using namespace std;

// Progress lines are written to stderr for networks with at least this many wires.
constexpr int PROGRESS_THRESHOLD = 28;
// Upper bound on the number of test cases t.
constexpr int MAX_T = 1000;
// Upper bound on the number of wires n (wire masks must fit into 64 bits).
constexpr int MAX_N = 64;
// Upper bound on the accumulated cost sum(m * phi^n) over all test cases.
constexpr double MAX_COST = 1e17;

// Word type for wire bitmasks: bit k stands for wire k, so at most
// numeric_limits<State>::digits wires can be represented.
// using State = uint32_t;
using State = uint64_t;

// Fibonacci-style table: fib1[0] = fib1[1] = 1, fib1[k] = fib1[k-1] + fib1[k-2]
// (1, 1, 2, 3, 5, 8, ...). is_sorting_network() asserts that one comparator
// layer on a component of `size` wires yields at most fib1[size] states.
constexpr array<State, numeric_limits<State>::digits + 1> fib1_gen() {
    array<State, numeric_limits<State>::digits + 1> table{};
    State cur = 1, next = 1;  // the pair (fib1[k], fib1[k + 1])
    for (size_t k = 0; k < table.size(); ++k) {
        table[k] = cur;
        const State sum = cur + next;
        cur = next;
        next = sum;
    }
    return table;
}
constexpr auto fib1 = fib1_gen();

// Disjoint-set union (union-find) with union by size and path compression.
// data[v] < 0 : v is a root and -data[v] is the size of its set;
// data[v] >= 0: data[v] is the parent of v.
struct Dsu {
    vector<int> data;

    explicit Dsu(int n) : data(n, -1) {}

    // Merge the sets containing a and b; returns true iff they were distinct.
    // The root of the larger set survives; on a size tie, a's root survives.
    // Keep this tie-break: is_sorting_network() replays the merges computed by
    // verify_layers() and expects `root_master` to remain the root.
    bool unite(int a, int b) {
        a = root(a);
        b = root(b);
        if (a == b) {
            return false;
        }
        if (data[a] > data[b]) {
            swap(a, b);
        }
        data[a] += data[b];
        data[b] = a;
        return true;
    }

    // True iff a and b are in the same set.
    bool equiv(int a, int b) {
        return root(a) == root(b);
    }

    // Root of a's set. Every node on the path is re-pointed directly at the
    // root; done iteratively so long parent chains cannot exhaust the stack.
    int root(int a) {
        int r = a;
        while (data[r] >= 0) {
            r = data[r];
        }
        while (a != r) {
            const int parent = data[a];
            data[a] = r;
            a = parent;
        }
        return r;
    }

    // Number of elements in a's set.
    int size(int a) {
        return -data[root(a)];
    }
};

// A comparator scheduled into a CmpLayerCmp job.
struct CeEntry {
    size_t cei;  // index of the comparator in the input list
    size_t a;    // first wire (0-indexed)
    size_t b;    // second wire (0-indexed)
};

// Job: merge the component rooted at root_slave into the one rooted at root_master.
struct CmpLayerCombine {
    size_t root_master;  // root that remains after the merge
    size_t root_slave;   // root that gets absorbed
};

// Job: apply the comparators in cmp_part, whose wires all lie in the
// component rooted at `root`.
struct CmpLayerCmp {
    size_t root;
    vector<CeEntry> cmp_part;
};

// One step of the schedule produced by verify_layers().
using CmpLayer = variant<CmpLayerCombine, CmpLayerCmp>;

// Build the evaluation schedule consumed by is_sorting_network().
//
// Wires are grouped into components tracked by a Dsu (initially one per wire).
// The returned jobs are, in order, either
//   - CmpLayerCombine{root_master, root_slave}: merge two components, or
//   - CmpLayerCmp{root, cmp_part}: apply comparators whose two wires already
//     lie in the component rooted at `root`.
// Each round scans the not-yet-scheduled comparators in input order, starting
// at cmp_skip. A comparator is eligible only if neither of its wires was
// touched earlier in the same round by another not-yet-scheduled comparator.
// This preserves the per-wire order of comparators and makes the comparators
// scheduled within one round wire-disjoint. Eligible comparators inside a
// single component are scheduled immediately. If a round schedules nothing,
// the components of the eligible cross-component comparator with the smallest
// combined size are merged instead. Ties are broken lexicographically by
// (root of a, root of b).
//
// n:      number of wires
// cmps:   comparators as 0-indexed wire pairs (a, b)
// return: the job list. For a round that scheduled comparators, there is one
//         CmpLayerCmp per component root, in increasing root order.
vector<CmpLayer> verify_layers(const size_t n, const vector<pair<size_t, size_t>> &cmps) {
    // cmp_layered[i]: comparator i has already been placed into a CmpLayerCmp job
    vector<bool> cmp_layered(cmps.size(), false);
    // index of the first comparator that is not yet scheduled
    size_t cmp_skip = 0;
    Dsu dsu(n);
    vector<CmpLayer> layers;
    while (cmp_skip < cmps.size()) {
        // wires already touched in this round (by a scheduled or a blocked comparator)
        vector<bool> layer_checked(n, false);
        // comparators scheduled in this round, bucketed by component root
        vector<vector<CeEntry>> layer(n);
        // best merge candidate (combined size, root_a, root_b); size n + 1 means "none yet"
        tuple<size_t, size_t, size_t> combine = {n + 1, 0, 0};
        for (size_t cei = cmp_skip; cei < cmps.size(); ++cei) {
            if (cmp_layered[cei]) {
                continue;
            }
            auto [a, b] = cmps[cei];
            bool checked = layer_checked[a] || layer_checked[b];
            // Mark both wires even when this comparator is skipped, so later
            // comparators on these wires stay blocked behind it.
            layer_checked[a] = layer_checked[b] = true;
            if (checked) {
                continue;
            }
            if (dsu.equiv(a, b)) {
                // Both wires are in one component: schedule the comparator now.
                size_t root_a = dsu.root(a);
                layer[root_a].emplace_back(CeEntry{cei, a, b});
                cmp_layered[cei] = true;
            } else {
                // Different components: remember the cheapest merge seen so far.
                size_t root_a = dsu.root(a);
                size_t root_b = dsu.root(b);
                size_t size_a = dsu.size(a);
                size_t size_b = dsu.size(b);
                combine = min(combine, make_tuple(size_a + size_b, root_a, root_b));
            }
        }
        if (all_of(layer.begin(), layer.end(), [](const auto &v) { return v.empty(); })) {
            // Nothing could be scheduled in this round: merge the cheapest candidate pair.
            size_t united_size, root_a, root_b;
            tie(united_size, root_a, root_b) = combine;
            if (n < united_size) {
                // No candidate was recorded. NOTE(review): this looks unreachable,
                // because the comparator at cmp_skip is always eligible. Confirm.
                break;
            }
            dsu.unite(root_a, root_b);
            // Dsu::unite may swap its arguments; record which root survived.
            size_t root_master = dsu.root(root_a);
            size_t root_slave = root_a == root_master ? root_b : root_a;
            layers.emplace_back(CmpLayerCombine{root_master, root_slave});
        } else {
            for (size_t root = 0; root < n; ++root) {
                if (layer[root].empty()) {
                    continue;
                }
                layers.emplace_back(CmpLayerCmp{root, move(layer[root])});
            }
            // Advance past the prefix of comparators that are now all scheduled.
            for (; cmp_skip < cmps.size(); ++cmp_skip) {
                if (!cmp_layered[cmp_skip]) {
                    break;
                }
            }
        }
    }
    return layers;
}

// Check whether the given comparator network is a sorting network.
//
// The network is simulated symbolically. Each state is a pair (z, o) of
// bitmasks: bit k of z means "wire k may hold 0" and bit k of o means
// "wire k may hold 1". A wire with both bits set may hold either value.
// Wires are grouped into components (Dsu) that are merged according to the
// schedule from verify_layers(), so state lists are kept per component.
//
// n:    number of wires, 2 <= n <= numeric_limits<State>::digits
// cmps: comparators as 0-indexed pairs (a, b) with a < b < n
// return: if it is a sorting network, the expected value holds unused[i] == true
//         for every comparator i that never had an effect during the simulation;
//         otherwise the error holds unsorted[k] == true for every adjacent pair
//         (k, k+1) of output positions that may end up out of order.
expected<vector<bool>, vector<bool>>
is_sorting_network(const size_t n, const vector<pair<size_t, size_t>> &cmps) {
    const State fullbit = ~(State)0;
    const int bits = numeric_limits<State>::digits;
    assert(2 <= n && n <= bits);
    size_t m = cmps.size();
    for (auto [a, b] : cmps) {
        // Ensure 0-indexed and a < b and b < n
        assert(0 <= a && a < b && b < n);
    }
    // unused[i] is true if the i-th comparator is unused
    // unsorted_i bit k is set if the k-th and (k+1)-th outputs may be out of order
    vector<bool> unused(m, true);
    State unsorted_i = 0;

    // Every wire starts as its own component with one state: "may be 0 or 1".
    vector<vector<pair<State, State>>> states(n);
    for (size_t i = 0; i < n; ++i) {
        states[i].emplace_back(State(1) << i, State(1) << i);
    }
    Dsu dsu(n);

    vector<CmpLayer> cmp_layers = verify_layers(n, cmps);
    for (auto &job : cmp_layers) {
        if (holds_alternative<CmpLayerCombine>(job)) {
            // Merge two components: their state lists combine as a cartesian product.
            CmpLayerCombine &combine = get<CmpLayerCombine>(job);
            size_t root_master = combine.root_master;
            size_t root_slave = combine.root_slave;
            size_t size_master = dsu.size(root_master);
            size_t size_slave = dsu.size(root_slave);
            dsu.unite(root_master, root_slave);
            size_t size_united = dsu.size(root_master);
            vector<pair<State, State>> united_status;
            size_t len_master = states[root_master].size();
            size_t len_slave = states[root_slave].size();
            united_status.reserve(len_master * len_slave);
            for (const auto &[sz, so] : states[root_slave]) {
                for (const auto &[mz, mo] : states[root_master]) {
                    united_status.emplace_back(mz | sz, mo | so);
                }
            }
            size_t len_united = united_status.size();
            states[root_master] = move(united_status);
            states[root_slave].clear();
            if (PROGRESS_THRESHOLD <= n) {
                cerr << "Combining, sizes: " << size_master << "+" << size_slave << "=>" << size_united
                << ", len: " << len_master << "*" << len_slave << "=>" << len_united
                << ", root_master: " << root_master << ", root_slave: " << root_slave << endl;
            }
        } else {
            CmpLayerCmp &cmp = get<CmpLayerCmp>(job);
            size_t root = cmp.root;
            size_t root_size = dsu.size(root);
            vector<CeEntry> &cmp_part = cmp.cmp_part;
            vector<pair<State, State>> states_next;
            vector<tuple<size_t, State, State>> stack;
            size_t len_pre = states[root].size();
            states_next.reserve(len_pre * 2);
            // Apply cmp_part[i..] to the state (z, o) and record the result.
            // When a comparator sees two undetermined wires, the outcome is split:
            // this path continues with wire b fixed to 1, and the other path
            // (both wires fixed to 0) is pushed onto `stack` to resume at i + 1.
            auto run_from = [&](size_t i, State z, State o) {
                for (; i < cmp_part.size(); ++i) {
                    const auto [cei, a, b] = cmp_part[i];
                    if (((o >> a) & 1) == 0 || ((z >> b) & 1) == 0) {
                        // a is surely 0 or b is surely 1: already in order, no effect
                        continue;
                    } else if (((z >> a) & 1) == 0 || ((o >> b) & 1) == 0) {
                        // a is surely 1 or b is surely 0: the comparator swaps the wires
                        unused[cei] = false;
                        State xz = ((z >> a) ^ (z >> b)) & 1;
                        State xo = ((o >> a) ^ (o >> b)) & 1;
                        z ^= (xz << a) | (xz << b);
                        o ^= (xo << a) | (xo << b);
                    } else {
                        // both wires undetermined: split into (a, b=1) and (a=0, b=0)
                        unused[cei] = false;
                        State qz = z, qo = (o ^ ((State)1 << a) ^ ((State)1 << b));
                        z ^= ((State)1 << b);
                        stack.emplace_back(i + 1, qz, qo);
                    }
                }
                states_next.emplace_back(z, o);
            };
            for (const auto &[z, o] : states[root]) {
                run_from(0, z, o);
            }
            while (!stack.empty()) {
                auto [i, z, o] = stack.back();
                stack.pop_back();
                run_from(i, z, o);
            }
            size_t len_gen = states_next.size();
            assert(len_gen <= fib1[root_size]);
            sort(states_next.begin(), states_next.end());
            decltype(states_next)::iterator dedup_it = unique(states_next.begin(), states_next.end());
            states_next.erase(dedup_it, states_next.end());
            size_t len_dedup = states_next.size();
            states[root] = move(states_next);
            if (PROGRESS_THRESHOLD <= n) {
                cerr << "AppliedCE, size: " << root_size << ", len: " 
                << len_pre << "=>" << len_gen << "=>" << len_dedup
                << ", root: " << root << ", cmp: [";
                for (const auto [cei, a, b] : cmp_part) {
                    cerr << "(" << cei << "," << a << "," << b << "),";
                }
                cerr << "]" << endl;
            }
        }
    }
    // Bits 0 .. n-2: the positions k that have a successor k+1.
    const State n1_mask = fullbit >> (bits - (n - 1));
    for (const auto &queue : states) {
        // Every wire of a component has its z or o bit set in every state, so the
        // first state's z|o is the component's wire mask (empty for merged-away roots).
        State q_mask = queue.empty() ? 0 : (queue.front().first | queue.front().second);
        // Wire k is in this component but k+1 is not: the two never interacted,
        // so this pair may be out of order.
        unsorted_i |= (q_mask & (~q_mask >> 1)) & n1_mask;
        for (const auto &[z, o] : queue) {
            // wire k may be 1 while wire k+1 may be 0
            unsorted_i |= (o & (z >> 1));
        }
    }
    // If unsorted contains true, it is not a sorting network
    if (unsorted_i != 0) {
        // Return the positions that may not be sorted
        vector<bool> unsorted(n - 1, false);
        for (size_t k = 0; k + 1 < n; ++k) {
            // Parenthesized: `x & 1 != 0` would parse as `x & (1 != 0)`.
            unsorted[k] = ((unsorted_i >> k) & 1) != 0;
        }
        return unexpected(unsorted);
    }
    // Return the unused comparators
    return unused;
}

// Brute-force sorting-network check for small n, by exhaustive simulation of
// every 0/1 input vector in bit-parallel form.
//
// Each wire holds 16 words x 64 bits = 1024 lanes per pass. Inputs are laid out
// as follows:
//   - wires 0..5 are encoded by the bit position inside a word (`lows` masks);
//   - wires 6..9 are encoded by the word index;
//   - wires >= 10 are encoded by the outer pass counter i.
// There are therefore 2^(max(n, 10) - 10) passes in total.
//
// n:    number of wires, 2 <= n <= MAX_N (the work grows as 2^n, so keep n small)
// cmps: comparators as 0-indexed pairs (a, b) with a < b < n
// return: if it is a sorting network, the expected value holds unused[i] == true
//         for every comparator i that never changed any lane; otherwise the
//         error holds unsorted[k] == true for every adjacent output pair
//         (k, k+1) that is out of order for some input.
expected<vector<bool>, vector<bool>> is_sorting_network_low(const size_t n, const vector<pair<size_t, size_t>> &cmps) {
    assert(2 <= n && n <= MAX_N && MAX_N <= UINT64_WIDTH && 10 <= MAX_N);
    for (auto [a, b] : cmps) {
        // 0-indexed
        assert(0 <= a && a < b && b < n);
    }
    using Lanes = array<uint64_t, 16>;
    const size_t m = cmps.size();
    vector<bool> unused(m, true), unsorted(n - 1, false);
    array<Lanes, MAX_N> states;
    const uint64_t z = UINT64_MAX;
    // lows[j] has bit p set iff bit j of p is set.
    const array<uint64_t, 6> lows = {0xaaaaaaaaaaaaaaaa, 0xcccccccccccccccc, 0xf0f0f0f0f0f0f0f0, 0xff00ff00ff00ff00, 0xffff0000ffff0000, 0xffffffff00000000};
    const uint64_t limit = 1ULL << (max(n, size_t{10}) - 10);
    for (uint64_t i = 0; i < limit; ++i) {
        // Load this pass's 1024 input vectors.
        for (size_t j = 0; auto x : lows) {
            states[j++].fill(x);
        }
        states[6] = { 0, z, 0, z, 0, z, 0, z, 0, z, 0, z, 0, z, 0, z };
        states[7] = { 0, 0, z, z, 0, 0, z, z, 0, 0, z, z, 0, 0, z, z };
        states[8] = { 0, 0, 0, 0, z, z, z, z, 0, 0, 0, 0, z, z, z, z };
        states[9] = { 0, 0, 0, 0, 0, 0, 0, 0, z, z, z, z, z, z, z, z };
        for (size_t j = 10; j < n; ++j) {
            // all-ones if bit (j - 10) of the pass counter is set, else all-zeros
            states[j].fill(-((i >> (j - 10)) & 1));
        }
        // Run the network: wire a receives the minimum (AND), wire b the maximum (OR).
        for (size_t j = 0; j < m; ++j) {
            const auto [a, b] = cmps[j];
            Lanes &va = states[a], &vb = states[b];
            Lanes lo, hi;
            for (size_t k = 0; k < 16; ++k) {
                lo[k] = va[k] & vb[k];
                hi[k] = va[k] | vb[k];
            }
            // va == va & vb means a <= b in every lane, so the comparator is a no-op.
            if (va != lo) {
                va = lo;
                vb = hi;
                unused[j] = false;
            }
        }
        // Output pair (j-1, j) is out of order if some lane has 1 on j-1 and 0 on j.
        for (size_t j = 1; j < n; ++j) {
            for (size_t k = 0; k < 16; ++k) {
                if ((states[j - 1][k] & ~states[j][k]) != 0) {
                    unsorted[j - 1] = true;
                    break;
                }
            }
        }
    }
    if (any_of(unsorted.begin(), unsorted.end(), [](bool x) { return x; })) {
        return unexpected(unsorted);
    }
    return unused;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int t;
    cin >> t;
    assert(1 <= t && t <= MAX_T);

    // φ = (1 + √5) / 2 : golden ratio 1.618033988749895
    const double phi = sqrt(1.25) + 0.5;
    double cost = 0.0;

    // Print the number of set flags, then, on the next line, the 1-indexed
    // positions of the set flags separated by single spaces (empty line if none).
    const auto print_flagged = [](const vector<bool> &flags) {
        cout << count(flags.begin(), flags.end(), true) << '\n';
        bool first = true;
        for (size_t idx = 0; idx < flags.size(); ++idx) {
            if (!flags[idx]) {
                continue;
            }
            if (!first) {
                cout << ' ';
            }
            cout << idx + 1;
            first = false;
        }
        cout << '\n';
    };

    for (int i = 0; i < t; ++i) {
        int n, m;
        cin >> n >> m;
        assert(2 <= n && n <= MAX_N && 1 <= m && m <= n * (n - 1) / 2);
        // Test case cost <= MAX_COST
        cost += m * pow(phi, n);
        assert(cost <= MAX_COST);
        // Read comparators: all m first wires, then all m second wires.
        vector<int> vec_a(m), vec_b(m);
        for (auto &a : vec_a) {
            cin >> a;
        }
        for (auto &b : vec_b) {
            cin >> b;
        }
        vector<pair<size_t, size_t>> cmps;
        cmps.reserve(m);
        for (int j = 0; j < m; ++j) {
            const int a = vec_a[j], b = vec_b[j];
            assert(1 <= a && a < b && b <= n);
            // 1-indexed to 0-indexed
            cmps.emplace_back(a - 1, b - 1);
        }
        // Small n: exhaustive bit-parallel check; otherwise the symbolic search.
        const auto is_sorting = n < 20 ? is_sorting_network_low(n, cmps) : is_sorting_network(n, cmps);
        if (is_sorting.has_value()) {
            const vector<bool> &unused = is_sorting.value();
            assert(unused.size() == static_cast<size_t>(m));
            cout << "Yes\n";
            // List the unused comparators j (1-indexed)
            print_flagged(unused);
        } else {
            const vector<bool> &unsorted = is_sorting.error();
            assert(unsorted.size() == static_cast<size_t>(n - 1));
            cout << "No\n";
            // List the positions k that may not be sorted (1-indexed)
            print_flagged(unsorted);
        }
    }

    return 0;
}
0