結果

問題 No.2977 Kth Xor Pair
ユーザー ooaiu
提出日時 2024-12-01 17:41:03
言語 C++23 (gcc 13.3.0 + boost 1.87.0)
結果 RE
実行時間 -
コード長 6,445 bytes
コンパイル時間 3,194 ms
コンパイル使用メモリ 248,684 KB
実行使用メモリ 16,128 KB
最終ジャッジ日時 2024-12-01 17:41:41
合計ジャッジ時間 31,726 ms
ジャッジサーバーID (参考情報) judge4 / judge3
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 15 RE * 18 TLE * 1
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>

using namespace std;

// Short type aliases.
// Signed integers.
using ll = long long;
using i128 = __int128;           // GCC/Clang extension, not standard C++
// Unsigned integers.
using u8 = std::uint8_t;
using u16 = std::uint16_t;
using u32 = std::uint32_t;
using u64 = std::uint64_t;
using u128 = unsigned __int128;  // GCC/Clang extension, not standard C++
// Floating point.
using f128 = __float128;         // GCC extension, not standard C++

int topbit(int x) { return (x == 0 ? -1 : 31 - __builtin_clz(x)); }
int topbit(u32 x) { return (x == 0 ? -1 : 31 - __builtin_clz(x)); }
int topbit(ll x) { return (x == 0 ? -1 : 63 - __builtin_clzll(x)); }
int topbit(u64 x) { return (x == 0 ? -1 : 63 - __builtin_clzll(x)); }

// https://maspypy.github.io/library/ds/binary_trie.hpp
// Path-compressed binary trie over LOG-bit keys, storing a multiplicity per key.
// Every node stores the `width` key bits on the edge leading into it, so chains of
// single-child levels occupy one node. kth/min/max/prefix_count/count take an
// `xor_val` and answer as if every stored key had been replaced by key ^ xor_val.
//
// Nodes are handed out from a pool owned by the trie; reset() recycles it.
// Node budget:
//   * non-persistent: at most 1 + 2 * (number of add calls) -- the root, plus at most
//     one split node and one new leaf per add;
//   * persistent (PERSISTENT == true): additionally a copy of every node an add modifies.
// Running out of pool space is caught by assert instead of writing past the buffer.
template <int LOG, bool PERSISTENT, int NODES, typename UINT = std::uint64_t, typename SIZE_TYPE = int>
struct Binary_Trie {
    // Keys are shifted by up to LOG bits, which is only defined below UINT's bit width.
    static_assert(0 <= LOG && LOG < int(sizeof(UINT) * 8),
                  "Binary_Trie: LOG must be less than the bit width of UINT");
    // highest_bit() below relies on a 64-bit count-leading-zeros.
    static_assert(sizeof(UINT) <= sizeof(unsigned long long),
                  "Binary_Trie: UINT wider than 64 bits is not supported");

    using T = SIZE_TYPE;
    struct Node {
        int width;  // number of key bits on the edge from the parent into this node
        UINT val;   // those `width` key bits
        T cnt;      // total multiplicity of the keys in this subtree
        Node *l, *r;
    };

    Node *pool;
    int pid;  // index of the next unused node in `pool`
    using np = Node *;

    // Pool with room for NODES nodes.
    Binary_Trie() : Binary_Trie(NODES) {}

    // Pool with room for `nodes` nodes, chosen at run time (e.g. 2 * N + 1 for N adds).
    explicit Binary_Trie(int nodes)
        : pool(nullptr), pid(0), storage_(static_cast<std::size_t>(nodes < 0 ? 0 : nodes)) {
        pool = storage_.data();
    }

    // Nodes point into the owned pool, so a member-wise copy would alias it. Moving is
    // safe: the vector's buffer (and therefore every node address) moves with it.
    Binary_Trie(const Binary_Trie &) = delete;
    Binary_Trie &operator=(const Binary_Trie &) = delete;
    Binary_Trie(Binary_Trie &&) = default;
    Binary_Trie &operator=(Binary_Trie &&) = default;

    // Number of nodes the pool can hold.
    int capacity() const { return static_cast<int>(storage_.size()); }

    // Recycles the whole pool; every root obtained before the call becomes invalid.
    void reset() { pid = 0; }

    // The empty trie is represented by a null root.
    np new_root() { return nullptr; }

    // Adds `cnt` copies of `val` and returns the (possibly new) root. In persistent mode
    // the nodes reachable from the old root are left untouched.
    np add(np root, UINT val, T cnt = 1) {
        if (!root) root = new_node(0, 0);
        assert((val >> LOG) == 0);  // key must fit in LOG bits (also rejects negative signed keys)
        return add_rec(root, LOG, val, cnt);
    }

    // Calls f(val, cnt) for every distinct stored key, in increasing key order.
    template <typename F>
    void enumerate(np root, F f) {
        auto dfs = [&](auto &dfs, np root, UINT val, int ht) -> void {
            if (ht == 0) {
                f(val, root->cnt);
                return;
            }
            np c = root->l;
            if (c) {
                dfs(dfs, c, val << (c->width) | (c->val), ht - (c->width));
            }
            c = root->r;
            if (c) {
                dfs(dfs, c, val << (c->width) | (c->val), ht - (c->width));
            }
        };
        if (root) dfs(dfs, root, 0, LOG);
    }

    // k-th smallest (0-indexed, with multiplicity) value of key ^ xor_val; returns that xored value.
    UINT kth(np root, T k, UINT xor_val) {
        assert(root && 0 <= k && k < root->cnt);
        return kth_rec(root, 0, k, LOG, xor_val) ^ xor_val;
    }

    // Smallest value of key ^ xor_val.
    UINT min(np root, UINT xor_val) {
        assert(root && root->cnt);
        return kth(root, 0, xor_val);
    }

    // Largest value of key ^ xor_val.
    UINT max(np root, UINT xor_val) {
        assert(root && root->cnt);
        return kth(root, (root->cnt) - 1, xor_val);
    }

    // Number of keys (with multiplicity) whose key ^ xor_val lies in [0, upper).
    T prefix_count(np root, UINT upper, UINT xor_val) {
        if (!root) return 0;
        return prefix_count_rec(root, LOG, upper, xor_val, 0);
    }

    // Number of keys (with multiplicity) whose key ^ xor_val lies in [lo, hi).
    T count(np root, UINT lo, UINT hi, UINT xor_val) {
        return prefix_count(root, hi, xor_val) - prefix_count(root, lo, xor_val);
    }

   private:
    std::vector<Node> storage_;  // owns the node pool; `pool` points at its buffer

    inline UINT mask(int k) { return (UINT(1) << k) - 1; }

    // Index of the most significant set bit of x, or -1 when x == 0.
    static int highest_bit(UINT x) {
        return x == 0 ? -1 : 63 - __builtin_clzll(static_cast<unsigned long long>(x));
    }

    np new_node(int width, UINT val) {
        assert(pid < capacity() && "Binary_Trie: node pool exhausted");
        np res = &pool[pid++];
        res->l = res->r = nullptr;
        res->width = width;
        res->val = val;
        res->cnt = 0;
        return res;
    }

    // Persistent mode: returns a fresh copy of c. Otherwise returns c itself.
    np copy_node(np c) {
        if (!c || !PERSISTENT) return c;
        assert(pid < capacity() && "Binary_Trie: node pool exhausted");
        np res = &pool[pid++];
        *res = *c;
        return res;
    }

    // Inserts the low `ht` bits of val below `root`; returns the node that replaces root.
    np add_rec(np root, int ht, UINT val, T cnt) {
        root = copy_node(root);
        root->cnt += cnt;
        if (ht == 0) return root;

        bool go_r = (val >> (ht - 1)) & 1;
        np c = (go_r ? root->r : root->l);
        if (!c) {
            // No child on this side: one leaf carries all remaining ht bits.
            c = new_node(ht, val);
            c->cnt = cnt;
            (go_r ? root->r : root->l) = c;
            return root;
        }
        int w = c->width;
        if ((val >> (ht - w)) == c->val) {
            // The whole edge label matches: descend.
            c = add_rec(c, ht - w, val & mask(ht - w), cnt);
            (go_r ? root->r : root->l) = c;
            return root;
        }
        // The edge label diverges after `same` bits: split it with a new inner node n.
        int same = w - 1 - highest_bit((val >> (ht - w)) ^ (c->val));
        np n = new_node(same, (c->val) >> (w - same));
        n->cnt = c->cnt + cnt;
        c = copy_node(c);
        c->width = w - same;
        c->val = c->val & mask(w - same);
        if ((val >> (ht - same - 1)) & 1) {
            n->l = c;
            n->r = new_node(ht - same, val & mask(ht - same));
            n->r->cnt = cnt;
        } else {
            n->r = c;
            n->l = new_node(ht - same, val & mask(ht - same));
            n->l->cnt = cnt;
        }
        (go_r ? root->r : root->l) = n;
        return root;
    }

    // Returns the original key (not xored) of the k-th smallest key ^ xor_val below root.
    UINT kth_rec(np root, UINT val, T k, int ht, UINT xor_val) {
        if (ht == 0) return val;
        np left = root->l, right = root->r;
        // With this level's xor bit set, the right subtree holds the smaller xored values.
        if ((xor_val >> (ht - 1)) & 1) std::swap(left, right);
        const T sl = (left ? left->cnt : T(0));
        np c = left;
        if (k >= sl) {
            c = right;
            k -= sl;
        }
        const int w = c->width;
        return kth_rec(c, val << w | (c->val), k, ht - w, xor_val);
    }

    // val holds the key bits above position ht; counts keys with key ^ xor_val < LIM.
    T prefix_count_rec(np root, int ht, UINT LIM, UINT xor_val, UINT val) {
        UINT now = (val << ht) ^ (xor_val);
        // High parts already decide the comparison: all of the subtree or none of it.
        if ((LIM >> ht) > (now >> ht)) return root->cnt;
        if (ht == 0 || (LIM >> ht) < (now >> ht)) return 0;
        T res = 0;
        for (int k = 0; k < 2; k++) {
            np c = (k == 0 ? root->l : root->r);
            if (c) {
                int w = c->width;
                res += prefix_count_rec(c, ht - w, LIM, xor_val, val << w | c->val);
            }
        }
        return res;
    }
};

void solve() {
    int N;
    int64_t K;
    cin >> N >> K;
    vector<int> A(N);
    for (int i = 0; i < N; i++) cin >> A[i];
    Binary_Trie<40, false, 300000, u64, u64> trie;
    const auto check = [&](int64_t x) -> bool {
        trie.reset();
        int64_t res = 0;
        auto root = trie.new_root();
        for (int i = 0; i < N; i++) {
            res += trie.prefix_count(root, x + 1, A[i]);
            root = trie.add(root, A[i]);
        }
        return res >= K;
    };
    int64_t lo = -1, hi = 1LL << 32;
    while (hi - lo > 1) {
        int64_t mi = (lo + hi) / 2;
        (check(mi) ? hi : lo) = mi;
    }
    cout << hi << endl;
}

// Entry point: switches to unsynchronized, untied iostreams, then runs solve() once per test case.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    // Single test case; read the count from input instead for multi-case variants.
    int test_cases = 1;
    // std::cin >> test_cases;
    for (int t = 0; t < test_cases; ++t) solve();
}
0