結果

問題 No.3189 Semifinal Stage
ユーザー zawakasu
提出日時 2025-06-22 16:49:25
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 3,247 ms / 4,000 ms
コード長 14,125 bytes
コンパイル時間 2,830 ms
コンパイル使用メモリ 179,860 KB
実行使用メモリ 50,372 KB
最終ジャッジ日時 2025-06-22 16:50:22
合計ジャッジ時間 46,333 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 30
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <iostream>
#include <iomanip>
#include <cassert>
#include <vector>
#include <algorithm>
#include <utility>
#include <numeric>
#include <tuple>
#include <ranges>
namespace ranges = std::ranges;
namespace views = std::views;
// #include "Src/Number/IntegerDivision.hpp"
// #include "Src/Utility/BinarySearch.hpp"
// #include "Src/Sequence/CompressedSequence.hpp"
// #include "Src/Sequence/RunLengthEncoding.hpp"
// #include "Src/Algebra/Group/AdditiveGroup.hpp"
// #include "Src/DataStructure/FenwickTree/FenwickTree.hpp"
// #include "Src/DataStructure/SegmentTree/SegmentTree.hpp"
// #include "Src/DataStructure/DisjointSetUnion/DisjointSetUnion.hpp"



#include <cstdint>
#include <cstddef>

namespace zawa {

// Short fixed-width integer names used throughout the library, grouped by width.

// 8/16-bit
using u8 = std::uint8_t;
using i16 = std::int16_t;
using u16 = std::uint16_t;

// 32-bit
using i32 = std::int32_t;
using u32 = std::uint32_t;

// 64-bit
using i64 = std::int64_t;
using u64 = std::uint64_t;

// 128-bit signed integer (GCC/Clang extension type).
using i128 = __int128_t;

// Unsigned type of sizes and indices.
using usize = std::size_t;

} // namespace zawa


#include <optional>

namespace zawa {

// Element of ChminMonoid: an optional priority paired with a payload.
// An absent priority stands for +infinity.
template <class T, class U>
class ChminMonoidData {
public:
    ChminMonoidData() = default;
    // Payload only; the priority stays absent (+infinity).
    ChminMonoidData(const U& value)
        : priority_{std::nullopt}, value_{value} {}
    // Payload with a finite priority.
    ChminMonoidData(const T& priority, const U& value)
        : priority_{priority}, value_{value} {}

    // True when no priority is stored (+infinity).
    constexpr bool infty() const noexcept {
        return !priority_.has_value();
    }
    // The stored priority. Calling this when infty() is true terminates the
    // program: optional::value() throws inside a noexcept function.
    constexpr const T& priority() const noexcept {
        return priority_.value();
    }
    // The payload.
    constexpr const U& value() const noexcept {
        return value_;
    }
    // Orders by priority, with +infinity greater than every finite priority
    // and two infinities comparing equal (neither is less).
    friend constexpr bool operator<(const ChminMonoidData& lhs, const ChminMonoidData& rhs) {
        return !lhs.infty() && (rhs.infty() || lhs.priority() < rhs.priority());
    }

private:
    std::optional<T> priority_{};
    U value_{};
};

// Monoid selecting the element with the smallest priority.
template <class T, class U>
struct ChminMonoid {
    using Element = ChminMonoidData<T, U>;
    // Identity element: priority +infinity.
    static Element identity() noexcept {
        return Element{};
    }
    // Returns the operand with the smaller priority; a tie keeps the left operand.
    static Element operation(const Element& l, const Element& r) noexcept {
        if (r < l) {
            return r;
        }
        return l;
    }
};

} // namespace zawa


#include <ostream>

namespace zawa {

template <class Structure>
class SparseTable {
private:
    using Value = typename Structure::Element;
    std::vector<u32> L;
    std::vector<std::vector<Value>> dat;
public:

    SparseTable() : L{}, dat{} {}
    SparseTable(const std::vector<Value>& a) : L(a.size() + 1), dat{} {
        for (u32 i{1} ; i < L.size() ; i++) {
            L[i] = L[i - 1] + (i >> (L[i - 1] + 1));
        }
        dat.resize(L.back() + 1);
        dat[0] = a;
        for (u32 i{1}, len{2} ; i < dat.size() ; i++, len <<= 1) {
            dat[i] = dat[i - 1];
            for (u32 j{} ; j + len - 1 < dat[i].size() ; j++) {
                dat[i][j] = Structure::operation(dat[i - 1][j], dat[i - 1][j + (len >> 1)]);
            }
        }
    }

    Value product(u32 l, u32 r) const {
        assert(l <= r);
        assert(l < dat[0].size());
        assert(r <= dat[0].size());
        u32 now{L[r - l]};
        return Structure::operation(dat[now][l], dat[now][r - (1 << now)]);
    }

    friend std::ostream& operator<<(std::ostream& os, const SparseTable<Structure>& spt) {
        for (u32 i{}, len{1} ; i < spt.dat.size() ; i++, len <<= 1) {
            os << "length = " << len << '\n';
            for (u32 j{} ; j + len - 1 < spt.dat[i].size() ; j++) {
                os << spt.dat[i][j] << (j + len == spt.dat[i].size() ? '\n' : ' ');
            }
        }
        return os;
    }
};

} // namespace zawa


namespace zawa {

template <class V>
class LowestCommonAncestor {
private:
    using Monoid = ChminMonoid<u32, V>;

public:
    LowestCommonAncestor() = default;

    LowestCommonAncestor(const std::vector<std::vector<V>>& tree, V r = V{}) 
        : n_{tree.size()}, depth_(tree.size()), L_(tree.size()), R_(tree.size()), st_{} {
            std::vector<typename Monoid::Element> init;
            init.reserve(2 * size());
            auto dfs{[&](auto dfs, V v, V p) -> void {
                depth_[v] = (p == INVALID ? 0u : depth_[p] + 1);
                L_[v] = (u32)init.size();
                for (auto x : tree[v]) {
                    if (x == p) {
                        continue;
                    }
                    init.emplace_back(depth_[v], v);
                    dfs(dfs, x, v);
                }
                R_[v] = (u32)init.size();
            }};
            dfs(dfs, r, INVALID);
            st_ = SparseTable<Monoid>(init);
    }

    V operator()(V u, V v) const {
        assert(verify(u));
        assert(verify(v));
        if (L_[u] > L_[v]) {
            std::swap(u, v);
        }
        return u == v ? u : st_.product(L_[u], R_[v]).value();
    }

    V lca(V u, V v) const {
        return (*this)(u, v);
    }

    inline u32 depth(V v) const noexcept {
        assert(verify(v));
        return depth_[v];
    }

    u32 distance(V u, V v) const {
        assert(verify(u));
        assert(verify(v));
        return depth(u) + depth(v) - 2u * depth((*this)(u, v));
    }

    bool isAncestor(V p, V v) const {
        assert(verify(p));
        assert(verify(v));
        return L_[p] <= L_[v] and R_[v] <= R_[p];
    }

protected:
    u32 left(V v) const noexcept {
        return L_[v];
    }

    inline usize size() const {
        return n_;
    }

    inline bool verify(V v) const {
        return v < (V)size();
    }

private:
    static constexpr V INVALID{static_cast<V>(-1)};
    usize n_{};
    std::vector<u32> depth_, L_, R_;
    SparseTable<Monoid> st_;
};

} // namespace zawa


namespace zawa {

// Auxiliary ("virtual") tree built on top of LowestCommonAncestor<V>.
//
// construct(vs) keeps the vertices of `vs` plus the LCA of every pair of
// vertices that are adjacent in DFS-entry order and not in an
// ancestor/descendant relation. Each kept vertex is linked to its nearest kept
// proper ancestor, and the link length is the depth difference in the
// original tree (see parentEdgeLength). Only one auxiliary tree exists at a
// time: construct() discards the previous one.
template <class V>
class AuxiliaryTree : public LowestCommonAncestor<V> {
public:
    using Super = LowestCommonAncestor<V>;

    AuxiliaryTree() = default;
    // T: adjacency list of the original tree; r: its root (both forwarded to LowestCommonAncestor).
    AuxiliaryTree(const std::vector<std::vector<V>>& T, V r = 0u) 
        : Super{ T, r }, T_(T.size()), dist_(T.size()), used_(T.size()) {}

    // Builds the auxiliary tree for `vs`, which must be non-empty and may
    // contain duplicates. Returns the root of the auxiliary tree, i.e. the one
    // kept vertex without a parent link.
    V construct(const std::vector<V>& vs) {
        assert(vs.size());
        clear();
        vs_ = vs;
        return build();
    }

    // Neighbours of v in the current auxiliary tree, both parent and children,
    // because addEdge records both directions. The list is empty for vertices
    // not in the tree.
    const std::vector<V>& operator[](V v) const {
        assert(Super::verify(v));
        return T_[v];
    }

    // Whether v belongs to the current auxiliary tree.
    inline bool contains(V v) const {
        assert(Super::verify(v));
        return used_[v];
    }

    // depth(v) - depth(auxiliary parent of v) in the original tree; 0 for the auxiliary root.
    inline u32 parentEdgeLength(V v) const {
        assert(contains(v));
        return dist_[v];
    }

    // Copy of the current auxiliary tree's vertices: the input vertices sorted
    // by DFS entry order, followed by the LCA vertices added by build().
    std::vector<V> current() const {
        return vs_;
    }

private:
    // Adjacency lists of the auxiliary tree, indexed by original vertex id.
    std::vector<std::vector<V>> T_{}; 
    // Vertices of the current auxiliary tree.
    std::vector<V> vs_{};
    // dist_[v]: length of the link from v to its auxiliary parent.
    std::vector<u32> dist_{};
    // used_[v]: membership flag for the current auxiliary tree.
    std::vector<bool> used_{};

    // Links auxiliary parent p with child v; p must be strictly shallower than v.
    void addEdge(V p, V v) {
        assert(Super::depth(p) < Super::depth(v));
        T_[p].push_back(v);
        T_[v].push_back(p);
        dist_[v] = Super::depth(v) - Super::depth(p);
    }

    // Stack-based construction. Vertices are processed in DFS entry order; the
    // stack holds a chain of kept vertices, each an ancestor of the one above it.
    V build() {
        // Sort by DFS entry position. Equal vertices become adjacent, so unique() drops duplicates.
        std::sort(vs_.begin(), vs_.end(), [&](V u, V v) -> bool {
                return Super::left(u) < Super::left(v);
                });
        vs_.erase(std::unique(vs_.begin(), vs_.end()), vs_.end());
        // Only the first k (original, deduplicated) entries are scanned below;
        // LCAs appended to vs_ inside the loop are not revisited.
        usize k{vs_.size()};
        std::vector<V> stack;
        stack.reserve(2u * vs_.size());
        stack.emplace_back(vs_[0]);
        for (usize i{} ; i + 1 < k ; i++) {
            if (!Super::isAncestor(vs_[i], vs_[i + 1])) {
                // vs_[i + 1] leaves the current branch. Pop the chain segment
                // deeper than w = lca and link each popped vertex to the one
                // beneath it. Then hang the last popped vertex under w, pushing
                // w as a new kept vertex unless it is already on top.
                V w{Super::lca(vs_[i], vs_[i + 1])};
                V l{stack.back()};
                stack.pop_back();
                while (stack.size() and LowestCommonAncestor<V>::depth(w) < LowestCommonAncestor<V>::depth(stack.back())) {
                    addEdge(stack.back(), l);
                    l = stack.back();
                    stack.pop_back();
                }
                if (stack.empty() or stack.back() != w) {
                    stack.emplace_back(w);
                    vs_.emplace_back(w);
                }
                addEdge(w, l);
            }
            stack.emplace_back(vs_[i + 1]);
        }
        // Link the remaining chain, each vertex to the one beneath it.
        while (stack.size() > 1u) {
            V l{stack.back()};
            stack.pop_back();
            addEdge(stack.back(), l);
        }
        // Mark every kept vertex, including the added LCAs.
        for (V v : vs_) {
            used_[v] = true;
        }
        // The last remaining stack entry never received a parent link: it is the root.
        return stack.back();
    }

    // Resets only the vertices of the previous auxiliary tree, so the cost is
    // proportional to that tree's size rather than to the whole tree.
    void clear() {
        for (V v : vs_) {
            T_[v].clear();
            used_[v] = false;
            dist_[v] = 0u;
        }
        vs_.clear();
    }
};

} // namespace zawa
using namespace zawa;
// #include "atcoder/modint"
// using mint = atcoder::modint998244353;

// Queries are processed offline in blocks of this many (see solve()).
constexpr int SQ = 500;
// N: number of vertices, Q: number of queries.
int N, Q;
// Query i has type T[i] (1 = toggle colour, 2 = distance query) and 0-indexed vertex V[i].
int T[100010], V[100010];
// g[v]: adjacency list of the input tree (0-indexed).
std::vector<int> g[100010];
// Reference brute force used by the DEBUG stress test. It replays all Q
// queries: type 1 toggles V[i]'s colour, and any other type is answered by a
// BFS from V[i] over the whole tree. Returns the answers in query order, or
// 1e9 for a query when no black vertex is reachable.
std::vector<int> naive() {
    std::vector<int> colour(N), answers;
    for (int qi = 0 ; qi < Q ; qi++) {
        const int src = V[qi];
        if (T[qi] == 1) {
            colour[src] ^= 1;
            continue;
        }
        // BFS; dist == -1 marks unvisited vertices, `order` doubles as the queue.
        std::vector<int> order{src}, dist(N, -1);
        dist[src] = 0;
        int best = static_cast<int>(1e9);
        for (std::size_t head = 0 ; head < order.size() ; ++head) {
            const int cur = order[head];
            if (colour[cur]) best = std::min(best, dist[cur]);
            for (const int nxt : g[cur]) {
                if (dist[nxt] != -1) continue;
                dist[nxt] = dist[cur] + 1;
                order.push_back(nxt);
            }
        }
        answers.push_back(best);
    }
    return answers;
}
// Answers all type-2 queries offline with square-root decomposition over the
// query sequence. Queries are taken in blocks of SQ; the vertices referenced by
// a block are "marked" and are the only ones whose colour can change inside it.
//  - dist[v] = distance from v to the nearest black vertex that is NOT marked
//    (its colour is fixed for the whole block), computed by a two-pass tree DP.
//  - Marked vertices are handled through an auxiliary tree built on them: each
//    type-2 query walks that tree from its (marked) source vertex.
// Returns the answers in query order, or 1e9 when no black vertex exists.
// NOTE: rec1/rec2 recurse along the original tree, so their depth equals the tree height.
std::vector<int> solve() {
    AuxiliaryTree at{std::vector(g, g + N)};
    std::vector<bool> black(N);
    std::vector<bool> mark(N);
    std::vector<int> res;
    res.reserve(Q);
    for (int i = 0 ; i < Q ; i += SQ) {
        // Collect the distinct vertices referenced by this block and mark them.
        std::vector<int> vs;
        vs.reserve(SQ);
        for (int j = 0 ; j < SQ and i + j < Q ; j++) if (!mark[V[i + j]]) {
            vs.push_back(V[i + j]);
            mark[V[i + j]] = true;
        }
        // for (int j = 0 ; j < N ; j++) std::cout << '(' << mark[j] << ',' << black[j] << ')' << ' ';
        // std::cout << std::endl;
        const auto root = at.construct(vs);
        const int INF = (int)1e9;
        std::vector<int> dist1(N, INF), dist(N, INF);
        // Pass 1 (rooted at 0): dist1[v] = distance from v to the nearest
        // unmarked black vertex inside v's subtree (INF if none).
        auto rec1 = [&](auto rec, int v, int p) -> int {
            if (!mark[v] and black[v]) dist1[v] = 0;
            for (int x : g[v]) if (x != p) dist1[v] = std::min(dist1[v], rec(rec, x, v) + 1);
            return dist1[v];
        };
        rec1(rec1, 0, -1);
        // Pass 2 (rerooting): pval = distance from p to the nearest unmarked
        // black vertex outside v's subtree (INF at the root). The two smallest
        // neighbour candidates (fir, sec) let each child exclude its own
        // contribution when its value is propagated downwards.
        auto rec2 = [&](auto rec, int v, int p, int pval) -> void {
            dist[v] = std::min(pval + 1, dist1[v]);
            int fir = INF, sec = INF;
            for (int x : g[v]) {
                const int val = 1 + (x == p ? pval : dist1[x]);
                if (fir > val) {
                    sec = fir;
                    fir = val;
                }
                else if (sec > val) {
                    sec = val;
                }
            }
            for (int x : g[v]) if (x != p) {
                const int val = dist1[x] + 1;
                // Best distance from v that avoids x's subtree (0 if v itself is an unmarked black).
                const int propa = !mark[v] and black[v] ? 0 :
                    (val == fir ? sec : fir);
                rec(rec, x, v, propa);
            }
        };
        rec2(rec2, 0, -1, INF);
        // atdep[v] = depth of v in the auxiliary tree (in auxiliary edges), -1
        // outside it. It is used to tell which endpoint of an auxiliary edge is the
        // child, i.e. the endpoint that stores the edge length.
        std::vector<int> atdep(N, -1);
        auto dfs = [&](auto dfs, int v, int p, int d) -> void {
            atdep[v] = d;
            for (auto x : at[v]) if (x != p) dfs(dfs, x, v, d + 1);
        };
        dfs(dfs, root, -1, 0);
        for (int j = 0 ; j < SQ and i + j < Q ; j++) {
            // Reset the mark so the next block starts clean; mark is not read
            // again in this block after the DP above.
            mark[V[i + j]] = false;
            if (T[i + j] == 1) black[V[i + j]] = !black[V[i + j]];
            else if (T[i + j] == 2) {
                const int s = V[i + j];
                int ans = dist[s];
                // Traverse the auxiliary tree from s. Accumulated lengths are exact
                // tree distances, and each auxiliary vertex is reached once (the
                // previous vertex is skipped). Expansion stops at a black vertex
                // because anything beyond it is farther away.
                std::vector<std::tuple<int, int, int>> que;
                que.push_back({s, -1, 0});
                for (int qt = 0 ; qt < std::ssize(que) ; qt++) {
                    const auto [v, p, d] = que[qt];
                    if (black[v]) {
                        ans = std::min(ans, d);
                    }
                    else {
                        for (auto x : at[v]) if (x != p) {
                            const auto w = at.parentEdgeLength(atdep[x] > atdep[v] ? x : v);
                            que.push_back({x, v, d + w});
                        }
                    }
                }
                res.push_back(ans);
            }
            else assert(false);
        }
    }
    return res;
}
#include <random>
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
#ifdef DEBUG
    // Stress test. It generates a random tree and random queries, prints them
    // to stdout, and compares solve() with naive(). It stops at the first
    // mismatch after printing both answer lists.
    std::mt19937 mt{std::random_device{}()};
    for (int testcase = 0 ; ; testcase++) {
        std::cerr << "------------------" << testcase << "------------------------------" << std::endl;
        N = mt() % 100 + 2;
        for (int v = 0 ; v < N ; v++) g[v].clear();
        std::cout << N << std::endl;
        // Vertex `child` is attached to a random earlier vertex.
        for (int child = 1 ; child < N ; child++) {
            const int par = mt() % child;
            std::cout << par + 1 << ' ' << child + 1 << std::endl;
            g[par].push_back(child);
            g[child].push_back(par);
        }
        Q = mt() % 100 + 1;
        // Count black vertices so distance queries are only generated while one exists.
        int blackCount = 0;
        std::vector<int> isBlack(N);
        std::cout << Q << '\n';
        for (int q = 0 ; q < Q ; q++) {
            const bool toggle = blackCount == 0 or mt() % 3 == 0;
            T[q] = toggle ? 1 : 2;
            V[q] = mt() % N;
            if (toggle) {
                blackCount += isBlack[V[q]] == 0 ? 1 : -1;
                isBlack[V[q]] ^= 1;
            }
            std::cout << T[q] << ' ' << V[q] + 1 << std::endl;
        }
        const auto fast = solve();
        const auto slow = naive();
        if (fast != slow) {
            for (int x : fast) std::cout << x << ' ';
            std::cout << std::endl;
            for (int x : slow) std::cout << x << ' ';
            std::cout << std::endl;
            std::exit(0);
        }
    }
#else
    // Input: N, then N-1 edges (1-indexed), then Q queries "type vertex".
    std::cin >> N;
    for (int e = 1 ; e < N ; e++) {
        int a, b;
        std::cin >> a >> b;
        --a;
        --b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    std::cin >> Q;
    for (int q = 0 ; q < Q ; q++) {
        std::cin >> T[q] >> V[q];
        --V[q];
    }
    const std::vector<int> answers = solve();
    for (const int x : answers) std::cout << x << '\n';
#endif
}
0