結果

問題 No.3442 Good Vertex Connectivity
コンテスト
ユーザー GOTKAKO
提出日時 2026-02-06 22:22:31
言語 C++17
(gcc 15.2.0 + boost 1.89.0)
結果
WA  
実行時間 -
コード長 16,794 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 3,454 ms
コンパイル使用メモリ 249,524 KB
実行使用メモリ 43,592 KB
最終ジャッジ日時 2026-02-06 22:23:29
合計ジャッジ時間 53,403 ms
ジャッジサーバーID
(参考情報)
judge4 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 49 WA * 20
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

#include <bits/stdc++.h>
using namespace std;

struct HLD1{
    // Heavy-light decomposition of a tree, vertex version (no edge data).
    // Filled by the constructor (vertices 0-indexed, rooted at Root):
    //   dist[v]        : depth of v (Root has depth 0)
    //   par[v]         : parent of v, -1 for Root
    //   siz[v]         : number of vertices in v's subtree
    //   in[v], out[v]  : v's subtree is exactly the pre-order positions [in[v], out[v])
    //   ord[i]         : the vertex at pre-order position i (inverse of in)
    //   head[v]        : topmost vertex of the heavy chain containing v
    // Side effect: each Graph[v] is reordered so that v's heaviest child comes first.
    // Both DFS passes are recursive, so recursion depth equals the tree height.
    int n = 0,tim = 0;
    vector<int> dist,in,out,siz,head,par,ord;
    HLD1(vector<vector<int>> &Graph,int Root = 0):n(Graph.size()),dist(n),in(n),out(n),siz(n),head(n),par(n),ord(n){
        iota(head.begin(),head.end(),0);
        // Pass 1: parents, depths and subtree sizes; moves the heaviest child to the front.
        auto dfs1 = [&](auto dfs1,int pos,int back,int d) -> int {
            par.at(pos) = back,dist.at(pos) = d;
            int maxsiz = 0,big = -1,idx = -1,ret = 1;
            for(auto to : Graph.at(pos)){
                idx++;
                if(to == back) continue;
                int kid = dfs1(dfs1,to,pos,d+1);
                ret += kid;
                if(maxsiz < kid) maxsiz = kid,big = idx;
            }
            if(big > -1) swap(Graph.at(pos).at(0),Graph.at(pos).at(big));
            return siz.at(pos) = ret;
        };
        int time = 0;
        // Pass 2: pre-order numbering, visiting the heavy child first so that every chain
        // occupies consecutive positions; the heavy child inherits its parent's head.
        auto dfs2 = [&](auto dfs2,int pos,int back) -> void {
            ord.at(time) = pos,in.at(pos) = time++;
            const auto &adj = Graph.at(pos);
            // After pass 1, adj[0] is the heavy child whenever pos has any child; it can only
            // equal `back` when pos is a leaf. (A `size() > 1` test would miss a root with
            // exactly one child and needlessly start a new chain there.)
            if(!adj.empty() && adj.at(0) != back) head.at(adj.at(0)) = head.at(pos);
            for(auto to : adj) if(to != back) dfs2(dfs2,to,pos);
            out.at(pos) = time;
        };
        dfs1(dfs1,Root,-1,0),dfs2(dfs2,Root,-1);
    }
    vector<pair<int,int>> findpath(int u,int v){ //O(logN).
        // Returns the u-v path as half-open ranges [l, r) of pre-order positions, i.e. ranges
        // of a segment tree whose position in[v] holds vertex v's value
        // (build it with give[in[i]] = A[i] for every vertex i).
        // The ranges are not emitted in path order: a non-commutative operation needs this
        // routine adapted.
        vector<pair<int,int>> ret;
        while(head.at(u) != head.at(v)){
            if(dist.at(head.at(u)) > dist.at(head.at(v))) swap(u,v);
            ret.push_back({in.at(head.at(v)),in.at(v)+1});
            v = par.at(head.at(v));
        }
        if(in.at(u) > in.at(v)) swap(u,v);
        ret.push_back({in.at(u),in.at(v)+1});
        return ret;
    }
    // Pre-order range [in[u], out[u]) covering u's subtree.
    pair<int,int> subtree(int u){return {in.at(u),out.at(u)};}
    int lca(int u,int v){ //O(logN).
        int hu = head.at(u),hv = head.at(v);
        while(hu != hv){
            // Lift whichever endpoint sits on the chain with the deeper head.
            if(dist.at(hu) < dist.at(hv)) v = par.at(hv),hv = head.at(v);
            else u = par.at(hu),hu = head.at(u);
        }
        if(dist.at(u) <= dist.at(v)) return u;
        else return v;
    }
    int jump(int u,int v,int k){ // vertex reached after k steps from u along the u->v path. O(logN).
        int l = lca(u,v);
        if(k <= dist.at(u)-dist.at(l)) return la(u,k);
        k -= dist.at(u)-dist.at(l);
        if(k <= dist.at(v)-dist.at(l)) return la(v,dist.at(v)-dist.at(l)-k);
        return -1; // -1 when the path has fewer than k edges.
    }
    int la(int u,int k){ // ancestor k levels above u, or -1 if none. O(logN).
        int hu = head.at(u);
        while(u != -1 && dist.at(u)-dist.at(hu) < k){
            // Skip the rest of the current chain plus the edge to the head's parent.
            k -= dist.at(u)-dist.at(hu)+1;
            u = par.at(hu);
            if(u != -1) hu = head.at(u);
        }
        if(u == -1) return -1;
        // Within a chain, pre-order positions are consecutive by depth.
        return ord.at(in.at(u)-k);
    }
    // Number of edges on the u-v path.
    int distance(int u,int v){return dist.at(u)+dist.at(v)-2*dist.at(lca(u,v));}
};

using S1 = pair<int,int>;
using FF = int;
// Lazy segment tree closely following the AtCoder Library's lazy_segtree, specialised to
//   S1 = {value, count}: op keeps the minimum value and how many positions attain it;
//   FF = int: an offset added to every value of a range, composed by addition.
// Positions are 0-indexed; ranges are half-open [l, r).
class LazySegmentTree{
    private:
    vector<S1> dat;   // node values; node 1 is the root, leaves are dat[siz .. siz+n)
    vector<FF> lazy;  // pending offsets of internal nodes 1 .. siz-1
    public:
    int siz = -1,n = -1,log = 0;

    S1 op(S1 a,S1 b){
        if(a.first == b.first) return {a.first,a.second+b.second};
        return min(a,b);
    }
    S1 mapping(FF f, S1 x){return {x.first+f,x.second};}
    FF composition(FF f, FF g){return f+g;}
    S1 e(){return {1001001001,0};}  // identity of op: a large sentinel value with count 0
    FF id(){return 0;}               // identity offset: mapping(id(), a) == a
    // op: combine nodes; mapping: apply an offset to a node; composition: offset on offset.

    LazySegmentTree(int N){init(N);}
    LazySegmentTree(const vector<S1> &A){init(A);} // size taken from A
    void init(int N){ // every position becomes e().
        siz = 1; n = N; log = 0;
        while(siz < n) siz *= 2,log++;
        dat.assign(siz*2,e());
        lazy.assign(siz,id());
    }
    void init(const vector<S1> &A){ // size taken from A.
        siz = 1; n = A.size(); log = 0;
        while(siz < n) siz <<= 1,log++;
        // assign (not resize): on re-initialisation resize would keep stale values in the
        // padding leaves [siz+n, 2*siz), and they would leak into every aggregate.
        dat.assign(siz*2,e());
        lazy.assign(siz,id());
        for(int i=0; i<n; i++) dat.at(i+siz) = A.at(i);
        for(int i=siz-1; i>0; i--) merge(i);
    }

    private:
    void eval(int u,FF f){
        // Apply f to node u and, for internal nodes, keep it pending for the children.
        if(u == 0) return;
        dat.at(u) = mapping(f,dat.at(u));
        if(u < siz) lazy.at(u) = composition(f,lazy.at(u));
    }
    void spread(int u){ // push node u's pending offset down to its children.
        if(u == 0 || id() == lazy.at(u)) return;
        eval(2*u,lazy.at(u));
        eval(2*u+1,lazy.at(u));
        lazy.at(u) = id();
    }
    void merge(int u){dat.at(u) = op(dat.at(u*2),dat.at(u*2+1));} // recompute from the two children.
    public:
    void set(int pos,S1 x){ // overwrite one position.
        assert(0 <= pos && pos < n);
        pos += siz;
        for(int i=log; i>0; i--) spread(pos>>i);
        dat.at(pos) = x;
        while(pos > 1) pos >>= 1,merge(pos);
    }
    void update(int pos,FF f){ // apply f to one position (the range overload takes one more argument).
        assert(0 <= pos && pos < n);
        pos += siz;
        for(int i=log; i>0; i--) spread(pos>>i);
        dat.at(pos) = mapping(f,dat.at(pos));
        while(pos > 1) pos >>= 1,merge(pos);
    }
    void update(int l,int r,FF f){ // apply f to every position in [l, r).
        assert(0 <= l && l <= r && r <= n);
        if(l == r) return;
        l += siz; r += siz;
        for(int i=log; i>0; i--){
            if(((l>>i)<<i) != l) spread(l>>i);
            if(((r>>i)<<i) != r) spread((r-1)>>i);
        }
        int memoL = l,memoR = r;
        while(l < r){
            if(l&1) eval(l++,f);
            if(r&1) eval(--r,f);
            l >>= 1; r >>= 1;
        }
        // Re-merge the ancestors of the two boundary nodes.
        l = memoL,r = memoR;
        while((l&1) == 0) l >>= 1;
        while((r&1) == 0) r >>= 1;
        r--; // the right boundary node is the one just left of r.
        while(l > 1) l >>= 1,merge(l);
        while(r > 1) r >>= 1,merge(r);
    }

    S1 get(int pos){ // read one position.
        assert(0 <= pos && pos < n);
        pos += siz;
        for(int i=log; i>0; i--) spread(pos>>i);
        return dat.at(pos);
    }
    S1 rangeans(int l,int r){ // aggregate of [l, r); e() when empty.
        assert(0 <= l && l <= r && r <= n);
        if(l == r) return e();
        l += siz; r += siz;
        for(int i=log; i>0; i--){
            if(((l>>i)<<i) != l) spread(l>>i);
            if(((r>>i)<<i) != r) spread((r-1)>>i);
        }

        S1 retl = e(),retr = e();
        while(l < r){
            if(l&1) retl = op(retl,dat.at(l++));
            if(r&1) retr = op(dat.at(--r),retr);
            l >>= 1; r >>= 1;
        }
        return op(retl,retr);
    }
    S1 allrange(){return dat.at(1);} // aggregate of everything.

    // Largest r such that f(aggregate of [l, r)) holds (f must be monotone and f(e()) true).
    int maxright(const function<bool(S1)> f,int l = 0){
        assert(0 <= l && l <= n && f(e()));
        if(l == n) return n;
        l += siz;
        for(int i=log; i>0; i--) spread(l>>i);
        S1 now = e();
        do{
            while(l%2 == 0) l >>= 1;
            S1 next = op(now,dat.at(l));
            if(f(next) == false){
                while(l < siz){
                    spread(l); l <<= 1;
                    next = op(now,dat.at(l));
                    if(f(next)) now = next,l++;
                }
                return l-siz;
            }
            now = next; l++;
        }while((l&-l) != l);
        return n;
    }
    // Smallest l such that f(aggregate of [l, r)) holds (f must be monotone and f(e()) true).
    int minleft(const function<bool(S1)> f,int r = -1){
        if(r == -1) r = n;
        assert(0 <= r && r <= n && f(e()));
        if(r == 0) return 0;
        r += siz;
        for(int i=log; i>0; i--) spread((r-1)>>i);
        S1 now = e();
        do{
            r--;
            while(r&1) r >>= 1;
            if(r == 0) r = 1;
            S1 next = op(dat.at(r),now);
            if(f(next) == false){
                while(r < siz){
                    spread(r);
                    r <<= 1; r++;
                    // Accumulating right-to-left: the new node goes on the LEFT of `now`.
                    next = op(dat.at(r),now);
                    if(f(next)) now = next,r--;
                }
                return r+1-siz;
            }
            now = next;
        }while((r&-r) != r);
        return 0;
    }
};

using SS = int;
// Point-update / range-query segment tree over SS = int with op = addition.
// Positions are 0-indexed; ranges are half-open [l, r).
class SegmentTree{
    public:
    int siz = -1,n = -1;
    vector<SS> dat;  // node 1 is the root; leaves are dat[siz .. siz+n)

    SS op(SS a, SS b){return a+b;}
    SS e(){return 0;}
    // How update() folds a value into a leaf: here it accumulates with op.
    void renew (SS &a,SS x){
        a = op(a,x);
        //a = x; // overwriting is available through set(pos,x).
        // other policies go here.
    }

    SegmentTree(int N){init(N);}
    SegmentTree(const vector<SS> &A){init(A);} // size taken from A
    void init(int N){
        // every position becomes the identity e().
        siz = 1; n = N;
        while(siz < n) siz *= 2;
        dat.assign(siz*2,e());
    }
    void init(const vector<SS> &A){ // size taken from A.
        siz = 1; n = A.size();
        while(siz < n) siz *= 2;
        // assign (not resize): on re-initialisation resize would keep stale values in the
        // padding leaves [siz+n, 2*siz), and they would leak into every sum.
        dat.assign(siz*2,e());
        for(int i=0; i<n; i++) dat.at(i+siz) = A.at(i);
        for(int i=siz-1; i>0; i--) dat.at(i) = op(dat.at(i*2),dat.at(i*2+1));
    }
    void set(int pos,SS x){ // overwrite one position.
        pos = pos+siz;
        dat.at(pos) = x;
        while(pos != 1){
            pos = pos/2;
            dat.at(pos) = op(dat.at(pos*2),dat.at(pos*2+1));
        }
    }
    void update(int pos,SS x){ // fold x into one position via renew().
        pos = pos+siz;
        renew(dat.at(pos),x);
        while(pos != 1){
            pos = pos/2;
            dat.at(pos) = op(dat.at(pos*2),dat.at(pos*2+1));
        }
    }
    SS findans(int l, int r){ // aggregate of [l, r).
        SS retl = e(),retr = e();
        l += siz,r += siz;
        while(l < r){
            if(l&1) retl = op(retl,dat.at(l++));
            if(r&1) retr = op(dat.at(--r),retr);
            l >>= 1; r >>= 1;
        }
        return op(retl,retr);
    }
    SS get(int pos){return dat.at(pos+siz);}
    SS rangeans(int l, int r){return findans(l,r);}
    SS allrange(){return dat.at(1);}

    // Right bounds are exclusive, left bounds inclusive, both as arguments and results.
    int maxright(const function<bool(SS)> f,int l = 0){
        // Returns the first r where f(aggregate of [l, r+1)) fails, i.e. the largest r with
        // f(aggregate of [l, r)) true; n when it never fails.
        l += siz; int r = n+siz;
        vector<int> ls,rs;
        while(l < r){
            if(l&1) ls.push_back(l++);
            if(r&1) rs.push_back(--r);
            l >>= 1; r >>= 1;
        }
        SS okl = e();
        for(int i=0; i<(int)ls.size(); i++){
            l = ls.at(i);
            SS now = op(okl,dat.at(l));
            if(!f(now)){
                while(l < siz){
                    l <<= 1;
                    now = op(okl,dat.at(l));
                    if(f(now)){okl = now; l++;}
                }
                return l-siz;
            }
            okl = now;
        }
        for(int i=(int)rs.size()-1; i>=0; i--){
            l = rs.at(i);
            SS now = op(okl,dat.at(l));
            if(!f(now)){
                while(l < siz){
                    l <<= 1;
                    now = op(okl,dat.at(l));
                    if(f(now)){okl = now; l++;}
                }
                return l-siz;
            }
            okl = now;
        }
        return n;
    }
    int minleft(const function<bool(SS)> f,int r = -1){
        // Returns the smallest l with f(aggregate of [l, r)) true; 0 when it never fails.
        if(r == -1) r = n;
        int l = siz; r += siz;
        vector<int> ls,rs;
        while(l < r){
            if(l&1) ls.push_back(l++);
            if(r&1) rs.push_back(--r);
            l >>= 1; r >>= 1;
        }
        SS okr = e();
        for(int i=0; i<(int)rs.size(); i++){
            r = rs.at(i);
            SS now = op(dat.at(r),okr);
            if(!f(now)){
                while(r < siz){
                    r <<= 1; r++;
                    now = op(dat.at(r),okr);
                    if(f(now)){okr = now; r--;}
                }
                return r+1-siz;
            }
            okr = now;
        }
        for(int i=(int)ls.size()-1; i>=0; i--){
            r = ls.at(i);
            SS now = op(dat.at(r),okr);
            if(!f(now)){
                while(r < siz){
                    r <<= 1; r++;
                    now = op(dat.at(r),okr);
                    if(f(now)){okr = now; r--;}
                }
                return r+1-siz;
            }
            okr = now;
        }
        return 0;
    }
};

// Reads a tree of N vertices, a 0/1 label C[v] per vertex (1 = "good"), then Q queries:
//   1 v   : toggle C[v].
//   2 x y : print the size of the minimal connected subgraph spanning every good vertex in
//           the part of the tree that is y's subtree when the tree is rooted at x (the
//           whole tree when x == y), or 0 if that part holds no good vertex.
//           (Query meaning inferred from the code; the statement is not in this file.)
// Everything is indexed by pre-order position in[v] of the HLD rooted at vertex 0:
//   S : set of pre-order positions of good vertices.
//   B : point-sum tree, B[in[v]] = C[v].
//   Z : lazy (min, count) tree, Z[in[v]] = number of good vertices in v's subtree.
// Inside subtree(r), where r is the LCA of the chosen good vertices, a vertex belongs to
// their spanning subgraph iff its Z count is > 0. So the answer is the range length minus
// the number of zero entries. Counts are never negative, so min == 0 identifies the zeros.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int N; cin >> N;
    vector<vector<int>> Graph(N);
    for(int i=0; i<N-1; i++){
        int u,v; cin >> u >> v;
        u--; v--;
        Graph.at(u).push_back(v);
        Graph.at(v).push_back(u);
    }
    vector<int> C(N);
    for(auto &c : C) cin >> c;

    HLD1 H(Graph); // also reorders Graph's adjacency lists (heavy child first).
    // H's vectors are never modified after construction, so bind them instead of copying.
    const auto &in = H.in,&out = H.out,&ord = H.ord,&dist = H.dist;
    set<int> S;
    vector<S1> give(N);
    SegmentTree B(N);
    {
        // Post-order: good-vertex count of each subtree, stored at the vertex's pre-order slot.
        auto dfs = [&](auto dfs,int pos,int back) -> int {
            int now = C.at(pos);
            if(C.at(pos) == 1) S.insert(in.at(pos)),B.update(in.at(pos),1);
            for(auto to : Graph.at(pos)) if(to != back) now += dfs(dfs,to,pos);
            give.at(in.at(pos)) = {now,1}; return now;
        };
        dfs(dfs,0,-1);
    }
    LazySegmentTree Z(give);

    // Number of pre-order positions in [lo, hi) whose current Z count is non-zero.
    auto countNonZero = [&](int lo,int hi) -> int {
        auto [low,c] = Z.rangeans(lo,hi);
        int answer = hi-lo;
        if(low == 0) answer -= c;
        return answer;
    };

    int Q; cin >> Q;
    while(Q--){
        int t; cin >> t;
        if(t == 1){
            int v; cin >> v; v--;
            if(C.at(v)) S.erase(in.at(v)),B.update(in.at(v),-1);
            else S.insert(in.at(v)),B.update(in.at(v),1);
            C.at(v) ^= 1;
            // Every ancestor of v (v included) gains or loses one good vertex.
            for(auto [l,r] : H.findpath(0,v)) Z.update(l,r,C.at(v)*2-1);
            continue;
        }

        int x,y; cin >> x >> y; x--,y--;
        if(S.empty()){cout << "0\n"; continue;}
        if(x == y){
            // Whole tree: the top of the spanning subgraph is the LCA of the first and
            // last good vertices in pre-order.
            int lca = H.lca(ord.at(*S.begin()),ord.at(*S.rbegin()));
            cout << countNonZero(in.at(lca),out.at(lca)) << "\n";
            continue;
        }

        // Replace x by y's neighbour on the path towards x.
        x = H.jump(x,y,H.distance(x,y)-1);
        if(dist.at(x) < dist.at(y)){
            // x is y's parent: the region is subtree(y) = pre-order [in[y], out[y]).
            auto itr = S.lower_bound(in.at(y));
            if(itr == S.end() || *itr >= out.at(y)){cout << "0\n"; continue;}
            int lt = *itr;
            int rt = *prev(S.lower_bound(out.at(y)));
            int lca = H.lca(ord.at(lt),ord.at(rt));
            cout << countNonZero(in.at(lca),out.at(lca)) << "\n";
            continue;
        }

        // x is a child of y: the region is the whole tree minus subtree(x).
        int lt = *S.begin(),rt = *S.rbegin();
        if(in.at(x) <= lt && lt < out.at(x)){
            // The first good position lies inside subtree(x); take the first one after it.
            // lt is a pre-order position, not a vertex, so the bound is out[x] (not out[lt]).
            auto itr = S.lower_bound(out.at(x));
            if(itr == S.end()){cout << "0\n"; continue;}
            lt = *itr;
        }
        if(in.at(x) <= rt && rt < out.at(x)){
            // The last good position lies inside subtree(x); take the last one before it.
            auto itr = S.lower_bound(in.at(x));
            if(itr == S.begin()){cout << "0\n"; continue;}
            rt = *prev(itr);
        }
        int lca = H.lca(ord.at(lt),ord.at(rt));
        int l1 = in.at(lca),r1 = out.at(lca),l2 = in.at(x),r2 = out.at(x);

        // Temporarily remove subtree(x)'s good vertices from the counts of its ancestors.
        int del = B.rangeans(l2,r2);
        if(del) for(auto [l,r] : H.findpath(0,x)) Z.update(l,r,-del);
        int answer;
        if(l1 < l2 && r2 <= r1){
            // lca is a proper ancestor of x: count subtree(lca) minus subtree(x).
            answer = countNonZero(l1,l2)+countNonZero(r2,r1);
        }
        else{
            answer = countNonZero(l1,r1);
        }
        cout << answer << "\n";
        if(del) for(auto [l,r] : H.findpath(0,x)) Z.update(l,r,del);
    }
}
0