結果

問題 No.3047 Verification of Sorting Network
ユーザー 👑 Mizar
提出日時 2025-03-08 01:11:21
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 143 ms / 2,000 ms
コード長 18,048 bytes
コンパイル時間 5,382 ms
コンパイル使用メモリ 332,060 KB
実行使用メモリ 8,612 KB
最終ジャッジ日時 2025-03-08 01:11:35
合計ジャッジ時間 13,482 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 61
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
#pragma GCC optimize ("O3")
#pragma GCC target ("arch=x86-64-v3,tune=native")
#if __cplusplus < 202002L
#error C++20 or newer is required
#endif

// yukicoder No.3047 Verification of Sorting Network (thread pool version)
// https://yukicoder.me/problems/no/3047

// Bitmask over the wires of the network: bit k describes wire k.
using State = std::uint64_t;
// Wire index (0-based); MAX_N is capped at 64 so 8 bits suffice.
using NodeIndex = std::uint8_t;
// Index into the comparator list.
using CmpIndex = std::size_t;

// Emit a "progress: NN%" indicator on stderr for large test cases.
static const bool SHOW_PROGRESS = true;
// Progress is only displayed for cases with n >= this many wires.
static const size_t PROGRESS_THRESHOLD_N = 24;
// Number of bits in State (64) — also the maximum supported wire count.
static const size_t STATE_BITS = std::numeric_limits<State>::digits;
static const size_t STATE_MAX = std::numeric_limits<State>::max();
// Problem constraints: max number of test cases and max wires per network.
static const size_t MAX_T = 100000;
static const size_t MAX_N = STATE_BITS;
// Upper bound on the accumulated m * phi^n work estimate (problem constraint).
static const double MAX_COST = 1e17;

// Calculate Fibonacci sequence values Fib1[0] = 1, Fib1[1] = 1, Fib1[n] = Fib1[n-1] + Fib1[n-2]
constexpr std::array<State, STATE_BITS + 1> fib1_gen(){
    std::array<State, STATE_BITS + 1> fib1;
    for (int i = 0; i <= STATE_BITS; ++i){
        fib1[i] = i < 2 ? 1 : fib1[i - 1] + fib1[i - 2];
    }
    return fib1;
}

namespace threadpool {
    // Fixed-size pool of worker threads consuming jobs from a shared FIFO
    // queue.  Jobs submitted via execute() run on an arbitrary worker; the
    // destructor signals shutdown, lets the workers drain the remaining
    // queue, and then joins them all.
    class ThreadPool {
    private:
        using Task = std::function<void()>;

        std::vector<std::thread> threads_;  // worker threads
        std::queue<Task> pending_;          // FIFO of submitted jobs
        std::mutex mtx_;                    // guards pending_ and stopping_
        std::condition_variable cv_;        // signals new work or shutdown
        bool stopping_ = false;             // set exactly once by the destructor

        // Worker loop: pop and run jobs until shutdown has been requested
        // and the queue is fully drained.
        void runWorker() {
            for (;;) {
                Task task;
                {
                    std::unique_lock<std::mutex> lock(mtx_);
                    cv_.wait(lock, [this]{ return stopping_ || !pending_.empty(); });
                    if (pending_.empty()) {
                        // wait() only returns with an empty queue when stopping_.
                        return;
                    }
                    task = std::move(pending_.front());
                    pending_.pop();
                }
                task();  // run outside the lock so jobs may enqueue more jobs
            }
        }

    public:
        // Spawn `size` worker threads immediately.
        ThreadPool(size_t size) {
            threads_.reserve(size);
            for (size_t i = 0; i < size; ++i) {
                threads_.emplace_back(&ThreadPool::runWorker, this);
            }
        }

        // Request shutdown, then join every worker once the queue drains.
        ~ThreadPool() {
            {
                std::unique_lock<std::mutex> lock(mtx_);
                stopping_ = true;
            }
            cv_.notify_all();
            for (auto &worker : threads_) {
                if (worker.joinable()) {
                    worker.join();
                }
            }
        }

        // Enqueue a job for asynchronous execution on some worker thread.
        void execute(Task task) {
            {
                std::unique_lock<std::mutex> lock(mtx_);
                pending_.push(std::move(task));
            }
            cv_.notify_one();
        }
    };
} // namespace threadpool

namespace sorting_network_check {

    // Fibonacci sequence (FIB1[i] = FIB1[i-1] + FIB1[i-2]), expected to have 65 elements (STATE_BITS+1)
    static const std::array<State, STATE_BITS+1> FIB1 = fib1_gen();

    // One unit of search work: a compactly encoded class of 0/1 input
    // patterns plus the comparator index at which to resume.
    struct Job {
        State z;     // bit k set: wire k may still carry a 0
        State o;     // bit k set: wire k may still carry a 1 (z & o = undetermined wires)
        CmpIndex i;  // index of the next comparator to apply
    };

    // JobResult
    struct JobResult {
        State branches;
        std::vector<bool> used;
        std::array<State, STATE_BITS> unsorted; // Bit set in unsorted[i]

        JobResult() = default;
        explicit JobResult(size_t cmpSize) {
            branches = 0;
            used.resize(cmpSize, false);
            for(auto &x : unsorted) x = 0ULL;
        }

        void merge(const JobResult &other) {
            branches += other.branches;
            for(size_t i = 0; i < used.size() && i < other.used.size(); i++){
                used[i] = (used[i] || other.used[i]);
            }
            for(size_t i = 0; i < STATE_BITS; i++){
                unsorted[i] |= other.unsorted[i];
            }
        }

        std::vector<bool> get_unused() const {
            std::vector<bool> ret(used.size(), false);
            for(size_t i = 0; i < used.size(); i++){
                ret[i] = !used[i];
            }
            return ret;
        }

        bool is_sorting_network() const {
            for(size_t i = 0; i < STATE_BITS; i++){
                if(unsorted[i] != 0ULL) return false;
            }
            return true;
        }

        std::array<State, STATE_BITS> get_unsorted_bitmap() const {
            return unsorted;
        }

        // Get all unsorted pairs
        std::vector<std::pair<size_t,size_t>> get_unsorted_allpairs() const {
            std::vector<std::pair<size_t,size_t>> v;
            for(size_t i = 0; i < STATE_BITS; i++){
                State z = unsorted[i];
                while(z != 0ULL){
                    unsigned long j = std::countr_zero(z);
                    v.push_back(std::make_pair(i,j));
                    z &= (z - 1ULL);
                }
            }
            return v;
        }

        // Extract only adjacent pairs (example implementation)
        std::vector<size_t> get_unsorted_adjacent() const {
            std::vector<size_t> v;
            for(size_t i = 0; i + 1 < STATE_BITS; i++){
                // Equivalent to ((unsorted[i] >> i) & 2) != 0
                // AND with 2 after i-bit shift
                if( ((unsorted[i] >> i) & 2ULL) != 0ULL ){
                    v.push_back(i);
                }
            }
            return v;
        }
    };

    // Search strategy for job_worker_func:
    //  - Split:  fork sub-states with few undetermined wires onto the pool
    //  - Single: run the whole sub-search inside the current job
    enum class SearchMode {
        Split,
        Single
    };

    // Future-like handle for the aggregated result of one network check.
    // Worker threads push partial JobResults onto the shared queue;
    // get_await() drains and merges them until every search branch (FIB1[n]
    // in total) is accounted for.
    class JobResultFuture {
    private:
        size_t n;  // number of wires; FIB1[n] is the total branch count to wait for
        // Using a mutex-protected queue instead of Rust's mpsc::Receiver
        // Here, it is assumed that results are pushed to a thread-safe queue
        std::shared_ptr<std::mutex> progressMutex;
        std::shared_ptr<std::queue<JobResult>> progressQueue;
        std::condition_variable* progressCondVar;  // non-owning; points at a static in is_sorting_network_future
        JobResult progress;                        // merged result so far
        std::atomic<bool>* doneFlag;               // non-owning; NOTE(review): never set true or freed anywhere in this file — verify intent

    public:
        JobResultFuture(size_t n_,
                        JobResult initRes,
                        std::shared_ptr<std::mutex> mu,
                        std::shared_ptr<std::queue<JobResult>> qu,
                        std::condition_variable* cv,
                        std::atomic<bool>* done)
            : n(n_),
              progressMutex(mu),
              progressQueue(qu),
              progressCondVar(cv),
              progress(initRes),
              doneFlag(done)
        {}

        // Block until all FIB1[n] branches have reported, merging partial
        // results as they arrive.  Optionally prints a progress percentage.
        JobResult get_await() {
            State fib1_n = FIB1[n];
            // For small n, start next_progress past the total so no progress
            // line is ever printed; for large n, print from 0% upward.
            State next_progress = PROGRESS_THRESHOLD_N <= n ? 0 : FIB1[n] + 1;
            while(progress.branches < fib1_n) {
                std::unique_lock<std::mutex> lk(*progressMutex);
                progressCondVar->wait(lk, [this]{return !progressQueue->empty() || (*doneFlag);});
                // Drain every queued partial result while holding the lock.
                while(!progressQueue->empty()) {
                    JobResult r = progressQueue->front();
                    progressQueue->pop();
                    progress.merge(r);
                }
                lk.unlock();

                if(SHOW_PROGRESS && progress.branches >= next_progress){
                    State percent = 0;
                    if(fib1_n != 0) {
                        percent = (progress.branches * 100ULL) / fib1_n;
                    }
                    std::cerr << "\rprogress: " << percent << "%" << std::flush;
                    // Smallest branch count that advances the display to percent+1.
                    next_progress = ((percent + 1ULL) * fib1_n - 1ULL) / 100ULL + 1ULL;
                }
                if((*doneFlag) && progressQueue->empty()){
                    break;
                }
            }
            // Every branch must have been counted exactly once.
            assert(fib1_n == progress.branches);
            if(SHOW_PROGRESS && PROGRESS_THRESHOLD_N <= n){
                std::cerr << std::endl;  // terminate the \r progress line
            }
            return progress;
        }
    };

    // execute_job — forward declaration only; job_worker_func recursively
    // submits split-off jobs through it, and its definition follows below.
    void execute_job(
        std::shared_ptr<threadpool::ThreadPool> pool,
        std::shared_ptr<std::mutex> progressMutex,
        std::shared_ptr<std::queue<JobResult>> progressQueue,
        std::condition_variable* progressCondVar,
        std::shared_ptr<std::vector<std::pair<NodeIndex, NodeIndex>>> cmp,
        SearchMode mode,
        uint32_t threshold,
        Job job,
        std::atomic<bool>* doneFlag
    );

    // Actual processing executed within the thread (execute_job main body).
    //
    // Depth-first search over classes of 0/1 input patterns.  A Job's state is
    // (z, o): bit k of z means "wire k may still be 0", bit k of o means
    // "wire k may still be 1"; z & o marks undetermined wires.  Comparators
    // (a, b) with a < b route the smaller value to wire a.  States are split
    // only where a comparator's outcome is not forced; completed branches are
    // tallied and verified against the FIB1 total at the end.
    inline void job_worker_func(
        std::shared_ptr<threadpool::ThreadPool> pool,
        std::shared_ptr<std::mutex> progressMutex,
        std::shared_ptr<std::queue<JobResult>> progressQueue,
        std::condition_variable* progressCondVar,
        std::shared_ptr<std::vector<std::pair<NodeIndex, NodeIndex>>> cmp,
        SearchMode mode,
        uint32_t threshold,
        Job job,
        std::atomic<bool>* doneFlag
    ) {
        std::vector<Job> stack;
        stack.reserve(MAX_N);
        stack.push_back(job);

        JobResult progressResult(cmp->size());
        State child_branches = 0;  // branch credit delegated to split-off jobs
        while(!stack.empty()){
            auto [z, o, i] = stack.back();
            stack.pop_back();
            while(true){
                if(i < cmp->size()){
                    auto [a, b] = (*cmp)[i];
                    i++;
                    // If wire a can never be 1 or wire b can never be 0, the
                    // comparator can never swap on this state: skip it.
                    if(((o >> a) & 1ULL) == 0ULL || ((z >> b) & 1ULL) == 0ULL){
                        continue;
                    }
                    if(((z >> a) & 1ULL) != 0ULL && ((o >> b) & 1ULL) != 0ULL){
                        // Both swap and no-swap inputs are covered here, so
                        // the comparator is definitely "used"; split the state
                        // by its sorted output pair: {(0,0)} vs {(0,1),(1,1)}.
                        progressResult.used[i - 1] = true;

                        auto qz = z;
                        auto qo = o;
                        // Branch (qz, qo): wires a and b both fixed to 0.
                        qo &= ~((1ULL << a)) & ~((1ULL << b));
                        // Branch (z, o): wire b fixed to 1, wire a unchanged.
                        z  &= ~(1ULL << b);

                        if(mode == SearchMode::Split){
                            uint32_t q_unknown_count = std::popcount(qz & qo);
                            if(q_unknown_count < threshold){
                                // Execute two more jobs here.  The delegated
                                // credit is exact: the q-branch has q unknowns,
                                // the other q+1, and FIB1[q] + FIB1[q+1] ==
                                // FIB1[q+2] == this state's branch total.
                                execute_job(pool, progressMutex, progressQueue, progressCondVar, cmp, SearchMode::Single, threshold, Job{qz, qo, i}, doneFlag);
                                execute_job(pool, progressMutex, progressQueue, progressCondVar, cmp, SearchMode::Single, threshold, Job{z,  o,  i}, doneFlag);
                                child_branches += FIB1[q_unknown_count + 2ULL];
                                break;
                            }
                        }
                        // If not splitting:
                        // o & (z >> 1) marks possible adjacent inversions
                        // (wire k may be 1 while wire k+1 may be 0).  A state
                        // with none has at most one undetermined wire, and
                        // FIB1[0] == FIB1[1] == 1, so it counts as 1 branch.
                        if((qo & (qz >> 1)) == 0) {
                            progressResult.branches += 1ULL;
                        } else {
                            stack.push_back(Job{qz, qo, i});
                        }
                        if((o & (z >> 1)) == 0) {
                            progressResult.branches += 1ULL;
                            break;
                        }
                    } else {
                        // Exactly one side is forced (a is always 1 or b is
                        // always 0), so the comparator deterministically
                        // exchanges the two wires' states.
                        progressResult.used[i - 1] = true;
                        // Bitwise operations for xz and xo: swap bits a and b
                        // of z and of o where they differ (XOR-swap trick).
                        State xz = ((z >> a) ^ (z >> b)) & 1ULL;
                        State xo = ((o >> a) ^ (o >> b)) & 1ULL;
                        z ^= (xz << a) | (xz << b);
                        o ^= (xo << a) | (xo << b);
                        if((o & (z >> 1)) == 0) {
                            progressResult.branches += 1ULL;
                            break;
                        }
                    }
                } else {
                    // All comparisons are done.
                    State np0 = (~z) | o;  // wires whose final value may be 1
                    State np1 = (~o) | z;  // wires whose final value may be 0
                    // For each pair j < k where out[j] may be 1 and out[k] may
                    // be 0, record (j, k) as a (potentially) unsorted pair.
                    while(np0 != 0ULL) {
                        unsigned long j = std::countr_zero(np0);
                        State u = progressResult.unsorted[j] | (np1 & ((UINT64_MAX << 1ULL) << j));
                        progressResult.unsorted[j] = u;
                        np0 &= (np0 - 1ULL);
                    }
                    // A finished class with u undetermined wires stands for
                    // FIB1[u] search branches.
                    progressResult.branches += FIB1[std::popcount(z & o)];
                    break;
                }
            }
        }

        // Consistency check: everything this job was responsible for is either
        // counted locally or delegated to split-off children.
        State testFib = FIB1[std::popcount(job.z & job.o)];
        assert(progressResult.branches + child_branches == testFib);

        // Send progress to the waiting JobResultFuture.
        {
            std::unique_lock<std::mutex> lock(*progressMutex);
            progressQueue->push(progressResult);
        }
        progressCondVar->notify_one();
    }

    // Submit one search job to the thread pool.  All shared state is captured
    // by value: the shared_ptrs keep the queue and comparator list alive for
    // the task's lifetime, while the condvar and done-flag pointers are
    // non-owning.
    void execute_job(
        std::shared_ptr<threadpool::ThreadPool> pool,
        std::shared_ptr<std::mutex> progressMutex,
        std::shared_ptr<std::queue<JobResult>> progressQueue,
        std::condition_variable* progressCondVar,
        std::shared_ptr<std::vector<std::pair<NodeIndex, NodeIndex>>> cmp,
        SearchMode mode,
        uint32_t threshold,
        Job job,
        std::atomic<bool>* doneFlag
    ) {
        auto task = [pool, progressMutex, progressQueue, progressCondVar,
                     cmp, mode, threshold, job, doneFlag] {
            job_worker_func(pool, progressMutex, progressQueue, progressCondVar,
                            cmp, mode, threshold, job, doneFlag);
        };
        pool->execute(std::move(task));
    }

    // is_sorting_network_future
    // Validate one test case, seed its root search job on the pool, and
    // return a future that awaits all of its branches (ported from Rust).
    JobResultFuture is_sorting_network_future(
        std::shared_ptr<threadpool::ThreadPool> pool,
        size_t n,
        std::shared_ptr<std::vector<std::pair<NodeIndex, NodeIndex>>> cmp
    ) {
        // Input validation (mirrors the Rust original's asserts).
        assert(n >= 2 && n <= MAX_N && n <= 255);
        assert(cmp->size() <= std::numeric_limits<CmpIndex>::max());
        for(auto &ab : (*cmp)) {
            assert(ab.first < ab.second && ab.second < n);
        }
        // Channel-like plumbing shared between the workers and the future.
        auto progressQueue = std::make_shared<std::queue<JobResult>>();
        auto progressMutex = std::make_shared<std::mutex>();
        // NOTE(review): doneFlag is heap-allocated, never set to true and
        // never freed — one atomic leaks per test case.  get_await()
        // terminates via the branch count instead; verify this is intentional.
        auto doneFlag = new std::atomic<bool>(false);
        // Function-local static: one condvar shared by all futures/workers.
        static std::condition_variable progressCondVar;

        // Determine mode: small n runs as one sequential job; larger n splits
        // sub-states with fewer than `threshold` undetermined wires into
        // independent pool jobs.
        SearchMode mode = SearchMode::Single;
        uint32_t threshold = 0;
        if(n >= 2 && n <= 22){
            mode = SearchMode::Single;
        } else {
            // Use a threshold equivalent to (n/2 + 8).max(20) for Split
            mode = SearchMode::Split;
            threshold = std::max<uint32_t>(n/2 + 8, 20);
        }

        // Issue the base job: all n wires undetermined.
        {
            Job j;
            // z = State::MAX >> (State::BITS - n as u32) → (1<<n)-1
            j.z = (UINT64_MAX >> (64 - n));
            j.o = UINT64_MAX; // State::MAX — every wire may also be 1
            j.i = 0;
            execute_job(pool, progressMutex, progressQueue, &progressCondVar, cmp, mode, threshold, j, doneFlag);
        }
        JobResult initRes(cmp->size());

        return JobResultFuture(n, initRes, progressMutex, progressQueue, &progressCondVar, doneFlag);
    }
} // namespace sorting_network_check

// Entry point: read T test cases, launch one asynchronous network check per
// case on a shared thread pool, then await and print results in input order.
int main(){
    using namespace sorting_network_check;
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    auto execution_start = std::chrono::steady_clock::now();

    size_t t;
    std::cin >> t;
    assert(t <= MAX_T);

    // Approximate value of the golden ratio φ = (1 + √5) / 2 ≃ 1.618033988749895
    double phi = std::sqrt(1.25) + 0.5;
    double cost = 0.0;  // running work estimate: sum over cases of m * phi^n

    // Number of threads (hardware_concurrency() may legitimately report 0).
    unsigned int number_childworker_thread = std::thread::hardware_concurrency();
    if(number_childworker_thread == 0) number_childworker_thread = 1;
    auto pool = std::make_shared<threadpool::ThreadPool>(number_childworker_thread);
    std::vector<sorting_network_check::JobResultFuture> futures;
    futures.reserve(t);

    for(size_t _i = 0; _i < t; _i++){
        size_t n, m;
        std::cin >> n >> m;
        assert(n >= 2 && n <= MAX_N);
        assert(m >= 1);

        // Reject inputs that would exceed the problem's total work budget.
        cost += double(m) * std::pow(phi, double(n));
        assert(cost <= MAX_COST);

        // Comparator information: all m values of a, then all m values of b
        // (both 1-indexed in the input).
        std::vector<size_t> a(m), b(m);
        for(size_t j = 0; j < m; j++){
            std::cin >> a[j];
        }
        for(size_t j = 0; j < m; j++){
            std::cin >> b[j];
        }
        for(size_t j = 0; j < m; j++){
            assert(a[j] >= 1 && a[j] <= n);
            assert(b[j] >= 1 && b[j] <= n);
        }

        // Convert to 0-indexed pairs (first < second required downstream).
        std::vector<std::pair<NodeIndex, NodeIndex>> cmp;
        cmp.reserve(m);
        for(size_t j = 0; j < m; j++){
            NodeIndex first = static_cast<NodeIndex>(a[j] - 1);
            NodeIndex second= static_cast<NodeIndex>(b[j] - 1);
            assert(first < second);
            cmp.push_back(std::make_pair(first, second));
        }

        auto sharedCmp = std::make_shared<std::vector<std::pair<NodeIndex, NodeIndex>>>(std::move(cmp));
        // Kick off the asynchronous check for this case.
        futures.push_back(is_sorting_network_future(pool, n, sharedCmp));
    }

    // Result processing, in the original input order.
    for(auto &f : futures){
        auto result = f.get_await();
        if(result.is_sorting_network()){
            // Sorting network: report count and 1-indexed list of comparators
            // that never swapped (removable comparators).
            auto unused = result.get_unused();
            std::cout << "Yes\n";
            size_t cnt = 0;
            for(bool u : unused) if(u) cnt++;
            std::cout << cnt << "\n";
            // Enumerate in 1-indexed
            std::vector<size_t> idxs;
            idxs.reserve(cnt);
            for(size_t i = 0; i < unused.size(); i++){
                if(unused[i]){
                    idxs.push_back(i + 1);
                }
            }
            for(size_t i = 0; i < idxs.size(); i++){
                std::cout << idxs[i] << ( (i+1 < idxs.size()) ? " " : "" );
            }
            std::cout << '\n';
        } else {
            // Not a sorting network: report 1-indexed positions k where the
            // adjacent pair (k, k+1) was observed unsorted.
            std::cout << "No\n";
            auto unsorted_bitmap = result.get_unsorted_bitmap();
            // Enumerate adjacent unsorted positions: bit k+1 of unsorted[k].
            std::vector<size_t> adj;
            for(size_t k = 0; k+1 < MAX_N; k++){
                if( ((unsorted_bitmap[k] >> k) & 2ULL) != 0ULL ){
                    adj.push_back(k);
                }
            }
            std::cout << adj.size() << "\n";
            // NOTE(review): when adj is empty no line is emitted here, while
            // the Yes branch always prints a (possibly empty) line — verify
            // the judge accepts this asymmetry (adj is presumably never empty
            // for a failing network).
            for(size_t i = 0; i < adj.size(); i++){
                std::cout << (adj[i] + 1) << ( (i+1 < adj.size()) ? ' ' : '\n' );
            }
        }
    }

    // NOTE(review): fflush(stdout) flushes the C buffer; std::cout (used
    // above) is flushed separately at program exit.
    std::fflush(stdout);
    auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(
        std::chrono::steady_clock::now() - execution_start
    ).count();
    std::cerr << ms << " ms\n";

    return 0;
}
0