結果
| 問題 |
No.3047 Verification of Sorting Network
|
| ユーザー |
👑 |
| 提出日時 | 2025-03-08 01:11:21 |
| 言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
AC
|
| 実行時間 | 143 ms / 2,000 ms |
| コード長 | 18,048 bytes |
| コンパイル時間 | 5,382 ms |
| コンパイル使用メモリ | 332,060 KB |
| 実行使用メモリ | 8,612 KB |
| 最終ジャッジ日時 | 2025-03-08 01:11:35 |
| 合計ジャッジ時間 | 13,482 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 61 |
ソースコード
#include <bits/stdc++.h>
#pragma GCC optimize ("O3")
#pragma GCC target ("arch=x86-64-v3,tune=native")
#if __cplusplus < 202002L
#error C++20 or newer is required
#endif
// yukicoder No.3047 Verification of Sorting Network (thread pool version)
// https://yukicoder.me/problems/no/3047
// ---- Basic types ----------------------------------------------------------
// Bitmask word holding one bit per network line.
using State = std::uint64_t;
// Index of a network line (0-based).
using NodeIndex = std::uint8_t;
// Index into the comparator list.
using CmpIndex = std::size_t;

// ---- Tunables and limits --------------------------------------------------
// Whether progress percentages are written to stderr while waiting for results.
static constexpr bool SHOW_PROGRESS = true;
// Smallest n for which progress reporting is active from the first update.
static constexpr std::size_t PROGRESS_THRESHOLD_N = 24;
// Number of bits in State; also the largest supported number of lines.
static constexpr std::size_t STATE_BITS = std::numeric_limits<State>::digits;
static constexpr std::size_t STATE_MAX = std::numeric_limits<State>::max();
// Upper bound accepted for the number of test cases.
static constexpr std::size_t MAX_T = 100000;
static constexpr std::size_t MAX_N = STATE_BITS;
// Upper bound accepted for the accumulated cost estimate computed in main().
static constexpr double MAX_COST = 1e17;
// Calculate Fibonacci sequence values Fib1[0] = 1, Fib1[1] = 1, Fib1[n] = Fib1[n-1] + Fib1[n-2]
constexpr std::array<State, STATE_BITS + 1> fib1_gen(){
std::array<State, STATE_BITS + 1> fib1;
for (int i = 0; i <= STATE_BITS; ++i){
fib1[i] = i < 2 ? 1 : fib1[i - 1] + fib1[i - 2];
}
return fib1;
}
namespace threadpool {
// Fixed-size pool of worker threads that run queued jobs in FIFO order.
//
// Lifetime: the destructor sets the stop flag, wakes every worker and joins
// them. Workers only exit once the queue is empty, so every job queued before
// destruction is still executed.
//
// Jobs must not throw: an exception escaping a job leaves the worker thread's
// entry function and terminates the program.
class ThreadPool {
private:
    using Job = std::function<void()>;

    std::vector<std::thread> workers;
    std::queue<Job> jobQueue;        // guarded by queueMutex
    std::mutex queueMutex;
    std::condition_variable condVar; // signalled on new job and on shutdown
    bool stopFlag = false;           // guarded by queueMutex

    // Worker loop: sleep until a job is available or shutdown is requested,
    // then run jobs outside the lock.
    void workerFunc() {
        for (;;) {
            Job job;
            {
                std::unique_lock<std::mutex> lock(queueMutex);
                condVar.wait(lock, [this] { return stopFlag || !jobQueue.empty(); });
                if (stopFlag && jobQueue.empty()) {
                    return;
                }
                job = std::move(jobQueue.front());
                jobQueue.pop();
            }
            job();
        }
    }

public:
    // Starts the worker threads.
    // size: requested number of workers. A request of 0 is treated as 1 so that
    //       queued jobs are never silently left unexecuted.
    explicit ThreadPool(std::size_t size) {
        const std::size_t workerCount = (size == 0) ? 1 : size;
        workers.reserve(workerCount);
        for (std::size_t i = 0; i < workerCount; ++i) {
            workers.emplace_back([this] { workerFunc(); });
        }
    }

    // Workers capture `this`, so the pool must stay at a fixed address.
    ThreadPool(const ThreadPool &) = delete;
    ThreadPool &operator=(const ThreadPool &) = delete;

    // Requests shutdown, lets the workers drain the queue, and joins them.
    ~ThreadPool() {
        {
            std::unique_lock<std::mutex> lock(queueMutex);
            stopFlag = true;
        }
        condVar.notify_all();
        for (auto &th : workers) {
            if (th.joinable()) {
                th.join();
            }
        }
    }

    // Queues `f` for execution on some worker thread and wakes one worker.
    void execute(Job f) {
        {
            std::unique_lock<std::mutex> lock(queueMutex);
            jobQueue.push(std::move(f));
        }
        condVar.notify_one();
    }
};
} // namespace threadpool
namespace sorting_network_check {
// Lookup table produced by fib1_gen(): STATE_BITS + 1 entries with
// FIB1[0] = FIB1[1] = 1 and FIB1[i] = FIB1[i-1] + FIB1[i-2].
static const std::array<State, STATE_BITS + 1> FIB1{fib1_gen()};
// One unit of search work: a pair of per-line bitmasks plus the index of the
// next comparator to apply. Used as an aggregate (Job{z, o, i}) and unpacked
// with structured bindings in job_worker_func.
struct Job {
// Bitmask over lines. The root job sets the low n bits.
// NOTE(review): presumably bit k set means "line k may hold 0" — confirm.
State z;
// Bitmask over lines. The root job sets every bit.
// NOTE(review): presumably bit k set means "line k may hold 1" — confirm.
// Lines with the bit set in both z and o are counted as "unknown" by the worker.
State o;
// Index into the comparator list of the next comparator to process.
CmpIndex i;
};
// JobResult
struct JobResult {
State branches;
std::vector<bool> used;
std::array<State, STATE_BITS> unsorted; // Bit set in unsorted[i]
JobResult() = default;
explicit JobResult(size_t cmpSize) {
branches = 0;
used.resize(cmpSize, false);
for(auto &x : unsorted) x = 0ULL;
}
void merge(const JobResult &other) {
branches += other.branches;
for(size_t i = 0; i < used.size() && i < other.used.size(); i++){
used[i] = (used[i] || other.used[i]);
}
for(size_t i = 0; i < STATE_BITS; i++){
unsorted[i] |= other.unsorted[i];
}
}
std::vector<bool> get_unused() const {
std::vector<bool> ret(used.size(), false);
for(size_t i = 0; i < used.size(); i++){
ret[i] = !used[i];
}
return ret;
}
bool is_sorting_network() const {
for(size_t i = 0; i < STATE_BITS; i++){
if(unsorted[i] != 0ULL) return false;
}
return true;
}
std::array<State, STATE_BITS> get_unsorted_bitmap() const {
return unsorted;
}
// Get all unsorted pairs
std::vector<std::pair<size_t,size_t>> get_unsorted_allpairs() const {
std::vector<std::pair<size_t,size_t>> v;
for(size_t i = 0; i < STATE_BITS; i++){
State z = unsorted[i];
while(z != 0ULL){
unsigned long j = std::countr_zero(z);
v.push_back(std::make_pair(i,j));
z &= (z - 1ULL);
}
}
return v;
}
// Extract only adjacent pairs (example implementation)
std::vector<size_t> get_unsorted_adjacent() const {
std::vector<size_t> v;
for(size_t i = 0; i + 1 < STATE_BITS; i++){
// Equivalent to ((unsorted[i] >> i) & 2) != 0
// AND with 2 after i-bit shift
if( ((unsorted[i] >> i) & 2ULL) != 0ULL ){
v.push_back(i);
}
}
return v;
}
};
// Strategy used by job_worker_func when a comparator forks the search.
enum class SearchMode {
// The worker may hand both halves of a fork to the thread pool as new jobs
// (only when the unknown-line count is below the threshold passed to the worker).
Split,
// The worker never spawns further jobs; forks are explored on its own stack.
Single
};
// Handle on one verification in flight. Worker jobs push partial JobResults
// onto a mutex-protected queue (standing in for a Rust mpsc::Receiver);
// get_await() drains that queue until all FIB1[n] branches are accounted for.
//
// Ownership: the mutex and queue are shared with the workers through
// shared_ptr. The condition variable and the done flag are borrowed raw
// pointers and must outlive this object; this class never frees them.
class JobResultFuture {
private:
    size_t n;
    std::shared_ptr<std::mutex> progressMutex;
    std::shared_ptr<std::queue<JobResult>> progressQueue;
    std::condition_variable* progressCondVar;
    JobResult progress;
    std::atomic<bool>* doneFlag;

public:
    // n_      - number of lines of the network being verified (indexes FIB1).
    // initRes - starting accumulator, normally an empty JobResult sized for the network.
    // mu, qu  - mutex and queue shared with the worker jobs.
    // cv      - condition variable the workers notify after pushing a result.
    // done    - flag that, when true, lets get_await() stop once the queue is empty.
    JobResultFuture(size_t n_,
                    JobResult initRes,
                    std::shared_ptr<std::mutex> mu,
                    std::shared_ptr<std::queue<JobResult>> qu,
                    std::condition_variable* cv,
                    std::atomic<bool>* done)
        : n(n_),
          progressMutex(std::move(mu)),
          progressQueue(std::move(qu)),
          progressCondVar(cv),
          progress(std::move(initRes)),
          doneFlag(done)
    {}

    // Blocks until the merged results cover FIB1[n] branches (or the done flag
    // is set and the queue is empty), then returns the merged result.
    // Progress percentages go to stderr when SHOW_PROGRESS is enabled.
    JobResult get_await() {
        const State fib1_n = FIB1[n];
        // For small n the first report threshold is placed above the total,
        // which suppresses progress output entirely.
        State next_progress = PROGRESS_THRESHOLD_N <= n ? 0 : FIB1[n] + 1;
        while (progress.branches < fib1_n) {
            {
                std::unique_lock<std::mutex> lk(*progressMutex);
                progressCondVar->wait(lk, [this] {
                    return !progressQueue->empty() || doneFlag->load();
                });
                while (!progressQueue->empty()) {
                    // Merge straight from the queue front; no intermediate copy.
                    progress.merge(progressQueue->front());
                    progressQueue->pop();
                }
            }
            if (SHOW_PROGRESS && progress.branches >= next_progress) {
                State percent = 0;
                if (fib1_n != 0) {
                    percent = (progress.branches * 100ULL) / fib1_n;
                }
                std::cerr << "\rprogress: " << percent << "%" << std::flush;
                // Smallest branch count at which the percentage increases again.
                next_progress = ((percent + 1ULL) * fib1_n - 1ULL) / 100ULL + 1ULL;
            }
            if (doneFlag->load()) {
                // The queue is shared with the workers, so it may only be
                // inspected while holding the mutex.
                std::lock_guard<std::mutex> relock(*progressMutex);
                if (progressQueue->empty()) {
                    break;
                }
            }
        }
        assert(fib1_n == progress.branches);
        if (SHOW_PROGRESS && PROGRESS_THRESHOLD_N <= n) {
            std::cerr << std::endl;
        }
        return progress;
    }
};
// Forward declaration of execute_job, needed because job_worker_func (below)
// calls it when it splits a branch, while execute_job's own definition calls
// job_worker_func.
//
// Queues one search job on `pool`. All arguments are forwarded unchanged to
// job_worker_func when the job runs:
//   progressMutex / progressQueue / progressCondVar - channel used to publish the JobResult.
//   cmp       - comparator list as (a, b) pairs of line indices.
//   mode      - whether the job may split into further pool jobs.
//   threshold - split threshold used in SearchMode::Split.
//   job       - starting state and comparator index.
//   doneFlag  - passed through to any nested execute_job calls.
void execute_job(
std::shared_ptr<threadpool::ThreadPool> pool,
std::shared_ptr<std::mutex> progressMutex,
std::shared_ptr<std::queue<JobResult>> progressQueue,
std::condition_variable* progressCondVar,
std::shared_ptr<std::vector<std::pair<NodeIndex, NodeIndex>>> cmp,
SearchMode mode,
uint32_t threshold,
Job job,
std::atomic<bool>* doneFlag
);
// Body of one pool job: explores every branch reachable from `job` using an
// explicit stack, accumulates a JobResult, and publishes it on the progress queue.
//
// Parameters:
//   pool, cmp, doneFlag - forwarded to execute_job() when this job splits; cmp is also
//                         the comparator list iterated here. doneFlag is only forwarded,
//                         it is neither read nor written in this function.
//   progressMutex, progressQueue, progressCondVar
//                       - channel on which the final JobResult is pushed and signalled.
//   mode                - SearchMode::Split allows handing a fork to the pool.
//   threshold           - in Split mode, a fork is handed off when the unknown-line
//                         count of its first half is below this value.
//   job                 - starting masks (z, o) and index i of the first comparator to apply.
//
// Invariant checked at the end (assert): the branches counted here plus those delegated
// to child jobs equal FIB1[popcount(job.z & job.o)].
inline void job_worker_func(
std::shared_ptr<threadpool::ThreadPool> pool,
std::shared_ptr<std::mutex> progressMutex,
std::shared_ptr<std::queue<JobResult>> progressQueue,
std::condition_variable* progressCondVar,
std::shared_ptr<std::vector<std::pair<NodeIndex, NodeIndex>>> cmp,
SearchMode mode,
uint32_t threshold,
Job job,
std::atomic<bool>* doneFlag
) {
// Explicit DFS stack of pending states; MAX_N is only an initial capacity.
std::vector<Job> stack;
stack.reserve(MAX_N);
stack.push_back(job);
JobResult progressResult(cmp->size());
// Branches delegated to jobs spawned via execute_job (counted by those jobs, not here).
State child_branches = 0;
while(!stack.empty()){
// z, o, i are mutable local copies of the popped state.
auto [z, o, i] = stack.back();
stack.pop_back();
// Apply comparators one by one to the current state until it is finished.
while(true){
if(i < cmp->size()){
auto [a, b] = (*cmp)[i];
i++;
// Skip without marking the comparator used when bit a of o is clear or bit b of z is clear.
// NOTE(review): presumably "line a is known 0" or "line b is known 1", so no exchange can happen — confirm.
if(((o >> a) & 1ULL) == 0ULL || ((z >> b) & 1ULL) == 0ULL){
continue;
}
// Both lines have their bit set in z and in o (both unknown): fork into two states.
if(((z >> a) & 1ULL) != 0ULL && ((o >> b) & 1ULL) != 0ULL){
progressResult.used[i - 1] = true;
// First half (qz, qo): bits a and b cleared from o. Second half (z, o): bit b cleared from z.
auto qz = z;
auto qo = o;
qo &= ~((1ULL << a)) & ~((1ULL << b));
z &= ~(1ULL << b);
if(mode == SearchMode::Split){
uint32_t q_unknown_count = std::popcount(qz & qo);
if(q_unknown_count < threshold){
// Hand both halves to the pool as non-splitting jobs and stop processing this state.
execute_job(pool, progressMutex, progressQueue, progressCondVar, cmp, SearchMode::Single, threshold, Job{qz, qo, i}, doneFlag);
execute_job(pool, progressMutex, progressQueue, progressCondVar, cmp, SearchMode::Single, threshold, Job{z, o, i}, doneFlag);
// The first half has q_unknown_count unknown lines and the second has one more (line a),
// so together they cover FIB1[q_unknown_count] + FIB1[q_unknown_count + 1] branches.
child_branches += FIB1[q_unknown_count + 2ULL];
break;
}
}
// Not splitting: finish or defer the first half, then continue with the second half in place.
// The test (o & (z >> 1)) == 0 means no line k has bit k set in o while bit k+1 is set in z.
// NOTE(review): presumably this identifies a state that is already sorted — confirm.
if((qo & (qz >> 1)) == 0) {
progressResult.branches += 1ULL;
} else {
stack.push_back(Job{qz, qo, i});
}
if((o & (z >> 1)) == 0) {
progressResult.branches += 1ULL;
break;
}
} else {
// No fork: the comparator is marked used and bits a and b are swapped in both masks.
progressResult.used[i - 1] = true;
// xz / xo are 1 exactly when bits a and b differ in z / o; XOR-ing both positions swaps them.
State xz = ((z >> a) ^ (z >> b)) & 1ULL;
State xo = ((o >> a) ^ (o >> b)) & 1ULL;
z ^= (xz << a) | (xz << b);
o ^= (xo << a) | (xo << b);
if((o & (z >> 1)) == 0) {
progressResult.branches += 1ULL;
break;
}
}
} else {
// All comparators have been applied: record remaining out-of-order pairs.
// np0 has bit k set when z's bit k is clear or o's bit k is set;
// np1 has bit k set when o's bit k is clear or z's bit k is set.
State np0 = (~z) | o;
State np1 = (~o) | z;
while(np0 != 0ULL) {
// For each set bit j of np0, OR the bits of np1 strictly above j into unsorted[j].
unsigned long j = std::countr_zero(np0);
State u = progressResult.unsorted[j] | (np1 & ((UINT64_MAX << 1ULL) << j));
progressResult.unsorted[j] = u;
np0 &= (np0 - 1ULL);
}
// The state still stands for FIB1[#unknown lines] branches.
progressResult.branches += FIB1[std::popcount(z & o)];
break;
}
}
}
// Consistency check: local plus delegated branches must match the job's expected total.
State testFib = FIB1[std::popcount(job.z & job.o)];
assert(progressResult.branches + child_branches == testFib);
// Publish the result under the queue mutex, then wake one waiter.
{
std::unique_lock<std::mutex> lock(*progressMutex);
progressQueue->push(progressResult);
}
progressCondVar->notify_one();
}
// Queues one search job on `pool`. The queued task owns copies of every
// argument and calls job_worker_func with them when a worker picks it up.
//
// NOTE(review): the task keeps a shared_ptr to the pool it is queued on, so a
// worker can briefly hold a reference to its own pool; presumably callers keep
// the pool alive until all jobs have finished — confirm.
void execute_job(
    std::shared_ptr<threadpool::ThreadPool> pool,
    std::shared_ptr<std::mutex> progressMutex,
    std::shared_ptr<std::queue<JobResult>> progressQueue,
    std::condition_variable* progressCondVar,
    std::shared_ptr<std::vector<std::pair<NodeIndex, NodeIndex>>> cmp,
    SearchMode mode,
    uint32_t threshold,
    Job job,
    std::atomic<bool>* doneFlag
) {
    auto task = [pool, progressMutex, progressQueue, progressCondVar,
                 cmp, mode, threshold, job, doneFlag] {
        job_worker_func(pool, progressMutex, progressQueue, progressCondVar,
                        cmp, mode, threshold, job, doneFlag);
    };
    pool->execute(std::move(task));
}
// Starts the verification of an n-line comparator network on `pool` and returns
// a JobResultFuture whose get_await() yields the merged JobResult.
//
//   pool - thread pool that runs the search jobs.
//   n    - number of lines; asserted to satisfy 2 <= n <= MAX_N (and n <= 255).
//   cmp  - comparators as 0-based (a, b) pairs; asserted to satisfy a < b < n.
//
// Networks with n <= 22 are searched by a single job; larger ones start in
// Split mode with threshold max(n/2 + 8, 20).
JobResultFuture is_sorting_network_future(
    std::shared_ptr<threadpool::ThreadPool> pool,
    size_t n,
    std::shared_ptr<std::vector<std::pair<NodeIndex, NodeIndex>>> cmp
) {
    // Precondition checks (mirroring the asserts of the Rust original).
    assert(n >= 2 && n <= MAX_N && n <= 255);
    assert(cmp->size() <= std::numeric_limits<CmpIndex>::max());
    for (auto &ab : (*cmp)) {
        assert(ab.first < ab.second && ab.second < n);
    }

    // Result channel: a queue and its mutex per verification, plus one
    // condition variable shared by every call of this function.
    static std::condition_variable progressCondVar;
    auto progressMutex = std::make_shared<std::mutex>();
    auto progressQueue = std::make_shared<std::queue<JobResult>>();
    // NOTE(review): allocated with new and never deleted or set to true in the
    // code visible here; ownership is not expressed by any type.
    auto doneFlag = new std::atomic<bool>(false);

    // Choose the search strategy from the network width.
    const bool runSingle = (n >= 2 && n <= 22);
    const SearchMode mode = runSingle ? SearchMode::Single : SearchMode::Split;
    const uint32_t threshold = runSingle ? 0 : std::max<uint32_t>(n/2 + 8, 20);

    // Root job: z has the low n bits set, o has every bit set, starting at comparator 0.
    const Job root{UINT64_MAX >> (64 - n), UINT64_MAX, 0};
    execute_job(pool, progressMutex, progressQueue, &progressCondVar, cmp, mode, threshold, root, doneFlag);

    return JobResultFuture(n, JobResult(cmp->size()), progressMutex, progressQueue, &progressCondVar, doneFlag);
}
} // namespace sorting_network_check
int main(){
using namespace sorting_network_check;
std::ios_base::sync_with_stdio(false);
std::cin.tie(nullptr);
auto execution_start = std::chrono::steady_clock::now();
size_t t;
std::cin >> t;
assert(t <= MAX_T);
// Approximate value of the golden ratio φ = (1 + √5) / 2 ≃ 1.618033988749895
double phi = std::sqrt(1.25) + 0.5;
double cost = 0.0;
// Number of threads
unsigned int number_childworker_thread = std::thread::hardware_concurrency();
if(number_childworker_thread == 0) number_childworker_thread = 1;
auto pool = std::make_shared<threadpool::ThreadPool>(number_childworker_thread);
std::vector<sorting_network_check::JobResultFuture> futures;
futures.reserve(t);
for(size_t _i = 0; _i < t; _i++){
size_t n, m;
std::cin >> n >> m;
assert(n >= 2 && n <= MAX_N);
assert(m >= 1);
cost += double(m) * std::pow(phi, double(n));
assert(cost <= MAX_COST);
// Comparator information
std::vector<size_t> a(m), b(m);
for(size_t j = 0; j < m; j++){
std::cin >> a[j];
}
for(size_t j = 0; j < m; j++){
std::cin >> b[j];
}
for(size_t j = 0; j < m; j++){
assert(a[j] >= 1 && a[j] <= n);
assert(b[j] >= 1 && b[j] <= n);
}
// Convert to 0-indexed
std::vector<std::pair<NodeIndex, NodeIndex>> cmp;
cmp.reserve(m);
for(size_t j = 0; j < m; j++){
NodeIndex first = static_cast<NodeIndex>(a[j] - 1);
NodeIndex second= static_cast<NodeIndex>(b[j] - 1);
assert(first < second);
cmp.push_back(std::make_pair(first, second));
}
auto sharedCmp = std::make_shared<std::vector<std::pair<NodeIndex, NodeIndex>>>(std::move(cmp));
// Check
futures.push_back(is_sorting_network_future(pool, n, sharedCmp));
}
// Result processing
for(auto &f : futures){
auto result = f.get_await();
if(result.is_sorting_network()){
auto unused = result.get_unused();
std::cout << "Yes\n";
size_t cnt = 0;
for(bool u : unused) if(u) cnt++;
std::cout << cnt << "\n";
// Enumerate in 1-indexed
std::vector<size_t> idxs;
idxs.reserve(cnt);
for(size_t i = 0; i < unused.size(); i++){
if(unused[i]){
idxs.push_back(i + 1);
}
}
for(size_t i = 0; i < idxs.size(); i++){
std::cout << idxs[i] << ( (i+1 < idxs.size()) ? " " : "" );
}
std::cout << '\n';
} else {
std::cout << "No\n";
auto unsorted_bitmap = result.get_unsorted_bitmap();
// Enumerate adjacent unsorted positions
std::vector<size_t> adj;
for(size_t k = 0; k+1 < MAX_N; k++){
if( ((unsorted_bitmap[k] >> k) & 2ULL) != 0ULL ){
adj.push_back(k);
}
}
std::cout << adj.size() << "\n";
for(size_t i = 0; i < adj.size(); i++){
std::cout << (adj[i] + 1) << ( (i+1 < adj.size()) ? ' ' : '\n' );
}
}
}
std::fflush(stdout);
auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - execution_start
).count();
std::cerr << ms << " ms\n";
return 0;
}