結果
| 問題 | No.3047 Verification of Sorting Network |
|---|---|
| ユーザー | 👑 |
| 提出日時 | 2025-03-08 01:11:21 |
| 言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 143 ms / 2,000 ms |
| コード長 | 18,048 bytes |
| コンパイル時間 | 5,382 ms |
| コンパイル使用メモリ | 332,060 KB |
| 実行使用メモリ | 8,612 KB |
| 最終ジャッジ日時 | 2025-03-08 01:11:35 |
| 合計ジャッジ時間 | 13,482 ms |
| ジャッジサーバーID (参考情報) | judge5 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 61 |
ソースコード
#include <bits/stdc++.h> #pragma GCC optimize ("O3") #pragma GCC target ("arch=x86-64-v3,tune=native") #if __cplusplus < 202002L #error C++20 or newer is required #endif // yukicoder No.3047 Verification of Sorting Network (thread pool version) // https://yukicoder.me/problems/no/3047 using State = std::uint64_t; using NodeIndex = std::uint8_t; using CmpIndex = std::size_t; static const bool SHOW_PROGRESS = true; static const size_t PROGRESS_THRESHOLD_N = 24; static const size_t STATE_BITS = std::numeric_limits<State>::digits; static const size_t STATE_MAX = std::numeric_limits<State>::max(); static const size_t MAX_T = 100000; static const size_t MAX_N = STATE_BITS; static const double MAX_COST = 1e17; // Calculate Fibonacci sequence values Fib1[0] = 1, Fib1[1] = 1, Fib1[n] = Fib1[n-1] + Fib1[n-2] constexpr std::array<State, STATE_BITS + 1> fib1_gen(){ std::array<State, STATE_BITS + 1> fib1; for (int i = 0; i <= STATE_BITS; ++i){ fib1[i] = i < 2 ? 1 : fib1[i - 1] + fib1[i - 2]; } return fib1; } namespace threadpool { class ThreadPool { private: typedef std::function<void()> Job; std::vector<std::thread> workers; std::queue<Job> jobQueue; std::mutex queueMutex; std::condition_variable condVar; bool stopFlag; void workerFunc() { while(true) { Job job; { std::unique_lock<std::mutex> lock(queueMutex); condVar.wait(lock, [this]{return stopFlag || !jobQueue.empty();}); if(stopFlag && jobQueue.empty()) { return; } job = std::move(jobQueue.front()); jobQueue.pop(); } job(); } } public: ThreadPool(size_t size) : stopFlag(false) { for(size_t i = 0; i < size; i++) { workers.emplace_back([this]{ workerFunc(); }); } } ~ThreadPool() { { std::unique_lock<std::mutex> lock(queueMutex); stopFlag = true; } condVar.notify_all(); for(auto &th : workers) { if(th.joinable()) { th.join(); } } } void execute(Job f) { { std::unique_lock<std::mutex> lock(queueMutex); jobQueue.push(std::move(f)); } condVar.notify_one(); } }; } // namespace threadpool namespace sorting_network_check { // 
Fibonacci sequence (FIB1[i] = FIB1[i-1] + FIB1[i-2]), expected to have 65 elements (STATE_BITS+1) static const std::array<State, STATE_BITS+1> FIB1 = fib1_gen(); struct Job { State z; State o; CmpIndex i; }; // JobResult struct JobResult { State branches; std::vector<bool> used; std::array<State, STATE_BITS> unsorted; // Bit set in unsorted[i] JobResult() = default; explicit JobResult(size_t cmpSize) { branches = 0; used.resize(cmpSize, false); for(auto &x : unsorted) x = 0ULL; } void merge(const JobResult &other) { branches += other.branches; for(size_t i = 0; i < used.size() && i < other.used.size(); i++){ used[i] = (used[i] || other.used[i]); } for(size_t i = 0; i < STATE_BITS; i++){ unsorted[i] |= other.unsorted[i]; } } std::vector<bool> get_unused() const { std::vector<bool> ret(used.size(), false); for(size_t i = 0; i < used.size(); i++){ ret[i] = !used[i]; } return ret; } bool is_sorting_network() const { for(size_t i = 0; i < STATE_BITS; i++){ if(unsorted[i] != 0ULL) return false; } return true; } std::array<State, STATE_BITS> get_unsorted_bitmap() const { return unsorted; } // Get all unsorted pairs std::vector<std::pair<size_t,size_t>> get_unsorted_allpairs() const { std::vector<std::pair<size_t,size_t>> v; for(size_t i = 0; i < STATE_BITS; i++){ State z = unsorted[i]; while(z != 0ULL){ unsigned long j = std::countr_zero(z); v.push_back(std::make_pair(i,j)); z &= (z - 1ULL); } } return v; } // Extract only adjacent pairs (example implementation) std::vector<size_t> get_unsorted_adjacent() const { std::vector<size_t> v; for(size_t i = 0; i + 1 < STATE_BITS; i++){ // Equivalent to ((unsorted[i] >> i) & 2) != 0 // AND with 2 after i-bit shift if( ((unsorted[i] >> i) & 2ULL) != 0ULL ){ v.push_back(i); } } return v; } }; enum class SearchMode { Split, Single }; // JobResultFuture class JobResultFuture { private: size_t n; // Using a mutex-protected queue instead of Rust's mpsc::Receiver // Here, it is assumed that results are pushed to a thread-safe queue 
std::shared_ptr<std::mutex> progressMutex; std::shared_ptr<std::queue<JobResult>> progressQueue; std::condition_variable* progressCondVar; JobResult progress; std::atomic<bool>* doneFlag; public: JobResultFuture(size_t n_, JobResult initRes, std::shared_ptr<std::mutex> mu, std::shared_ptr<std::queue<JobResult>> qu, std::condition_variable* cv, std::atomic<bool>* done) : n(n_), progressMutex(mu), progressQueue(qu), progressCondVar(cv), progress(initRes), doneFlag(done) {} JobResult get_await() { State fib1_n = FIB1[n]; State next_progress = PROGRESS_THRESHOLD_N <= n ? 0 : FIB1[n] + 1; while(progress.branches < fib1_n) { std::unique_lock<std::mutex> lk(*progressMutex); progressCondVar->wait(lk, [this]{return !progressQueue->empty() || (*doneFlag);}); while(!progressQueue->empty()) { JobResult r = progressQueue->front(); progressQueue->pop(); progress.merge(r); } lk.unlock(); if(SHOW_PROGRESS && progress.branches >= next_progress){ State percent = 0; if(fib1_n != 0) { percent = (progress.branches * 100ULL) / fib1_n; } std::cerr << "\rprogress: " << percent << "%" << std::flush; next_progress = ((percent + 1ULL) * fib1_n - 1ULL) / 100ULL + 1ULL; } if((*doneFlag) && progressQueue->empty()){ break; } } assert(fib1_n == progress.branches); if(SHOW_PROGRESS && PROGRESS_THRESHOLD_N <= n){ std::cerr << std::endl; } return progress; } }; // execute_job definition void execute_job( std::shared_ptr<threadpool::ThreadPool> pool, std::shared_ptr<std::mutex> progressMutex, std::shared_ptr<std::queue<JobResult>> progressQueue, std::condition_variable* progressCondVar, std::shared_ptr<std::vector<std::pair<NodeIndex, NodeIndex>>> cmp, SearchMode mode, uint32_t threshold, Job job, std::atomic<bool>* doneFlag ); // Actual processing executed within the thread (execute_job main body) inline void job_worker_func( std::shared_ptr<threadpool::ThreadPool> pool, std::shared_ptr<std::mutex> progressMutex, std::shared_ptr<std::queue<JobResult>> progressQueue, std::condition_variable* 
progressCondVar, std::shared_ptr<std::vector<std::pair<NodeIndex, NodeIndex>>> cmp, SearchMode mode, uint32_t threshold, Job job, std::atomic<bool>* doneFlag ) { std::vector<Job> stack; stack.reserve(MAX_N); stack.push_back(job); JobResult progressResult(cmp->size()); State child_branches = 0; while(!stack.empty()){ auto [z, o, i] = stack.back(); stack.pop_back(); while(true){ if(i < cmp->size()){ auto [a, b] = (*cmp)[i]; i++; if(((o >> a) & 1ULL) == 0ULL || ((z >> b) & 1ULL) == 0ULL){ continue; } if(((z >> a) & 1ULL) != 0ULL && ((o >> b) & 1ULL) != 0ULL){ progressResult.used[i - 1] = true; auto qz = z; auto qo = o; qo &= ~((1ULL << a)) & ~((1ULL << b)); z &= ~(1ULL << b); if(mode == SearchMode::Split){ uint32_t q_unknown_count = std::popcount(qz & qo); if(q_unknown_count < threshold){ // Execute two more jobs here execute_job(pool, progressMutex, progressQueue, progressCondVar, cmp, SearchMode::Single, threshold, Job{qz, qo, i}, doneFlag); execute_job(pool, progressMutex, progressQueue, progressCondVar, cmp, SearchMode::Single, threshold, Job{z, o, i}, doneFlag); child_branches += FIB1[q_unknown_count + 2ULL]; break; } } // If not splitting if((qo & (qz >> 1)) == 0) { progressResult.branches += 1ULL; } else { stack.push_back(Job{qz, qo, i}); } if((o & (z >> 1)) == 0) { progressResult.branches += 1ULL; break; } } else { progressResult.used[i - 1] = true; // Bitwise operations for xz and xo State xz = ((z >> a) ^ (z >> b)) & 1ULL; State xo = ((o >> a) ^ (o >> b)) & 1ULL; z ^= (xz << a) | (xz << b); o ^= (xo << a) | (xo << b); if((o & (z >> 1)) == 0) { progressResult.branches += 1ULL; break; } } } else { // All comparisons are done State np0 = (~z) | o; State np1 = (~o) | z; while(np0 != 0ULL) { unsigned long j = std::countr_zero(np0); State u = progressResult.unsorted[j] | (np1 & ((UINT64_MAX << 1ULL) << j)); progressResult.unsorted[j] = u; np0 &= (np0 - 1ULL); } progressResult.branches += FIB1[std::popcount(z & o)]; break; } } } // Check State testFib = 
FIB1[std::popcount(job.z & job.o)]; assert(progressResult.branches + child_branches == testFib); // Send progress { std::unique_lock<std::mutex> lock(*progressMutex); progressQueue->push(progressResult); } progressCondVar->notify_one(); } // execute_job void execute_job( std::shared_ptr<threadpool::ThreadPool> pool, std::shared_ptr<std::mutex> progressMutex, std::shared_ptr<std::queue<JobResult>> progressQueue, std::condition_variable* progressCondVar, std::shared_ptr<std::vector<std::pair<NodeIndex, NodeIndex>>> cmp, SearchMode mode, uint32_t threshold, Job job, std::atomic<bool>* doneFlag ) { pool->execute([=]{ job_worker_func(pool, progressMutex, progressQueue, progressCondVar, cmp, mode, threshold, job, doneFlag); }); } // is_sorting_network_future // Similar to a Rust function JobResultFuture is_sorting_network_future( std::shared_ptr<threadpool::ThreadPool> pool, size_t n, std::shared_ptr<std::vector<std::pair<NodeIndex, NodeIndex>>> cmp ) { // Rustのassert assert(n >= 2 && n <= MAX_N && n <= 255); assert(cmp->size() <= std::numeric_limits<CmpIndex>::max()); for(auto &ab : (*cmp)) { assert(ab.first < ab.second && ab.second < n); } // Something like a channel auto progressQueue = std::make_shared<std::queue<JobResult>>(); auto progressMutex = std::make_shared<std::mutex>(); auto doneFlag = new std::atomic<bool>(false); static std::condition_variable progressCondVar; // Determine mode SearchMode mode = SearchMode::Single; uint32_t threshold = 0; if(n >= 2 && n <= 22){ mode = SearchMode::Single; } else { // Use a threshold equivalent to (n/2 + 8).max(20) for Split mode = SearchMode::Split; threshold = std::max<uint32_t>(n/2 + 8, 20); } // Issue the base job { Job j; // z = State::MAX >> (State::BITS - n as u32) → (1<<n)-1 j.z = (UINT64_MAX >> (64 - n)); j.o = UINT64_MAX; // State::MAX j.i = 0; execute_job(pool, progressMutex, progressQueue, &progressCondVar, cmp, mode, threshold, j, doneFlag); } JobResult initRes(cmp->size()); return JobResultFuture(n, initRes, 
progressMutex, progressQueue, &progressCondVar, doneFlag); } } // namespace sorting_network_check int main(){ using namespace sorting_network_check; std::ios_base::sync_with_stdio(false); std::cin.tie(nullptr); auto execution_start = std::chrono::steady_clock::now(); size_t t; std::cin >> t; assert(t <= MAX_T); // Approximate value of the golden ratio φ = (1 + √5) / 2 ≃ 1.618033988749895 double phi = std::sqrt(1.25) + 0.5; double cost = 0.0; // Number of threads unsigned int number_childworker_thread = std::thread::hardware_concurrency(); if(number_childworker_thread == 0) number_childworker_thread = 1; auto pool = std::make_shared<threadpool::ThreadPool>(number_childworker_thread); std::vector<sorting_network_check::JobResultFuture> futures; futures.reserve(t); for(size_t _i = 0; _i < t; _i++){ size_t n, m; std::cin >> n >> m; assert(n >= 2 && n <= MAX_N); assert(m >= 1); cost += double(m) * std::pow(phi, double(n)); assert(cost <= MAX_COST); // Comparator information std::vector<size_t> a(m), b(m); for(size_t j = 0; j < m; j++){ std::cin >> a[j]; } for(size_t j = 0; j < m; j++){ std::cin >> b[j]; } for(size_t j = 0; j < m; j++){ assert(a[j] >= 1 && a[j] <= n); assert(b[j] >= 1 && b[j] <= n); } // Convert to 0-indexed std::vector<std::pair<NodeIndex, NodeIndex>> cmp; cmp.reserve(m); for(size_t j = 0; j < m; j++){ NodeIndex first = static_cast<NodeIndex>(a[j] - 1); NodeIndex second= static_cast<NodeIndex>(b[j] - 1); assert(first < second); cmp.push_back(std::make_pair(first, second)); } auto sharedCmp = std::make_shared<std::vector<std::pair<NodeIndex, NodeIndex>>>(std::move(cmp)); // Check futures.push_back(is_sorting_network_future(pool, n, sharedCmp)); } // Result processing for(auto &f : futures){ auto result = f.get_await(); if(result.is_sorting_network()){ auto unused = result.get_unused(); std::cout << "Yes\n"; size_t cnt = 0; for(bool u : unused) if(u) cnt++; std::cout << cnt << "\n"; // Enumerate in 1-indexed std::vector<size_t> idxs; idxs.reserve(cnt); 
for(size_t i = 0; i < unused.size(); i++){ if(unused[i]){ idxs.push_back(i + 1); } } for(size_t i = 0; i < idxs.size(); i++){ std::cout << idxs[i] << ( (i+1 < idxs.size()) ? " " : "" ); } std::cout << '\n'; } else { std::cout << "No\n"; auto unsorted_bitmap = result.get_unsorted_bitmap(); // Enumerate adjacent unsorted positions std::vector<size_t> adj; for(size_t k = 0; k+1 < MAX_N; k++){ if( ((unsorted_bitmap[k] >> k) & 2ULL) != 0ULL ){ adj.push_back(k); } } std::cout << adj.size() << "\n"; for(size_t i = 0; i < adj.size(); i++){ std::cout << (adj[i] + 1) << ( (i+1 < adj.size()) ? ' ' : '\n' ); } } } std::fflush(stdout); auto ms = std::chrono::duration_cast<std::chrono::milliseconds>( std::chrono::steady_clock::now() - execution_start ).count(); std::cerr << ms << " ms\n"; return 0; }