結果

問題 No.2163 LCA Sum Query
ユーザー tko919tko919
提出日時 2022-12-21 04:44:51
言語 C++17 (gcc 13.2.0 + boost 1.83.0)
結果 AC
実行時間 580 ms / 6,000 ms
コード長 11,525 bytes
コンパイル時間 4,046 ms
コンパイル使用メモリ 224,708 KB
実行使用メモリ 15,260 KB
最終ジャッジ日時 2023-08-11 10:41:46
合計ジャッジ時間 19,332 ms
ジャッジサーバーID (参考情報) judge12 / judge14
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 2 ms 4,388 KB
testcase_01 AC 1 ms 4,384 KB
testcase_02 AC 2 ms 4,384 KB
testcase_03 AC 2 ms 4,380 KB
testcase_04 AC 2 ms 4,384 KB
testcase_05 AC 2 ms 4,384 KB
testcase_06 AC 2 ms 4,380 KB
testcase_07 AC 2 ms 4,384 KB
testcase_08 AC 2 ms 4,516 KB
testcase_09 AC 2 ms 4,388 KB
testcase_10 AC 2 ms 4,380 KB
testcase_11 AC 2 ms 4,384 KB
testcase_12 AC 150 ms 4,420 KB
testcase_13 AC 199 ms 11,824 KB
testcase_14 AC 143 ms 12,676 KB
testcase_15 AC 60 ms 4,384 KB
testcase_16 AC 124 ms 11,976 KB
testcase_17 AC 140 ms 7,524 KB
testcase_18 AC 100 ms 4,384 KB
testcase_19 AC 72 ms 4,384 KB
testcase_20 AC 13 ms 7,824 KB
testcase_21 AC 69 ms 7,536 KB
testcase_22 AC 406 ms 12,932 KB
testcase_23 AC 260 ms 13,072 KB
testcase_24 AC 354 ms 13,044 KB
testcase_25 AC 268 ms 12,988 KB
testcase_26 AC 559 ms 12,944 KB
testcase_27 AC 364 ms 13,000 KB
testcase_28 AC 580 ms 13,112 KB
testcase_29 AC 362 ms 12,948 KB
testcase_30 AC 148 ms 15,112 KB
testcase_31 AC 166 ms 15,260 KB
testcase_32 AC 164 ms 14,340 KB
testcase_33 AC 166 ms 14,892 KB
testcase_34 AC 196 ms 13,228 KB
testcase_35 AC 149 ms 13,244 KB
testcase_36 AC 208 ms 13,116 KB
testcase_37 AC 140 ms 13,120 KB
testcase_38 AC 217 ms 13,676 KB
testcase_39 AC 170 ms 13,728 KB
testcase_40 AC 217 ms 13,672 KB
testcase_41 AC 167 ms 13,764 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#line 1 "library/Template/template.hpp"
#include <bits/stdc++.h>
using namespace std;

// Half-open int loop: `var` takes every value in [lo, hi).
#define rep(var,lo,hi) for(int var=(int)(lo);var<(int)(hi);++var)
// Expands to the begin/end iterator pair of a container (for <algorithm> calls).
#define ALL(c) (c).begin(),(c).end()
using ll=long long;
// Large sentinels chosen so that adding two of them does not overflow.
constexpr int inf=0x3fffffff;
constexpr ll INF=0x1fffffffffffffff;
// If b is greater than a, store b into a. Returns whether a changed.
template<typename T>inline bool chmax(T& a,T b){
    if(!(a<b))return false;
    a=b;
    return true;
}
// If b is smaller than a, store b into a. Returns whether a changed.
template<typename T>inline bool chmin(T& a,T b){
    if(!(a>b))return false;
    a=b;
    return true;
}
#line 2 "library/Utility/fastio.hpp"
#include <unistd.h>

// Buffered whitespace-separated reader for stdin and buffered writer for stdout.
// Output is flushed only when the write buffer fills up and in the destructor.
class FastIO{
    static constexpr int L=1<<16;
    char rdbuf[L];
    int rdLeft=0,rdRight=0;  // unread input is rdbuf[rdLeft, rdRight)
    // Move the unread tail to the front of rdbuf and refill the rest from stdin.
    inline void reload(){
        int len=rdRight-rdLeft;
        std::memmove(rdbuf,rdbuf+rdLeft,len);
        rdLeft=0,rdRight=len;
        rdRight+=static_cast<int>(std::fread(rdbuf+len,1,L-len,stdin));
    }
    // Skip bytes <= ' ' (spaces, newlines, control chars). Returns false once stdin is exhausted.
    inline bool skip(){
        for(;;){
            while(rdLeft!=rdRight and rdbuf[rdLeft]<=' ')rdLeft++;
            if(rdLeft==rdRight){
                reload();
                if(rdLeft==rdRight)return false;
            }
            else break;
        }
        return true;
    }
    // Bounds check first, so no stale byte past rdRight is ever inspected.
    inline bool digitAt(int pos)const{
        return pos<rdRight and rdbuf[pos]>='0' and rdbuf[pos]<='9';
    }
    // Parse an optionally negative decimal integer.
    // Negative values are accumulated downward, so numeric_limits<T>::min() parses without overflow.
    // NOTE(review): relies on the single 20-byte look-ahead reload, so a token is assumed to fit in it.
    template<typename T,std::enable_if_t<std::is_integral<T>::value,int> =0>inline bool _read(T& x){
        if(!skip())return false;
        if(rdLeft+20>=rdRight)reload();
        bool neg=false;
        if(rdbuf[rdLeft]=='-'){
            neg=true;
            rdLeft++;
        }
        x=0;
        while(digitAt(rdLeft)){
            int d=rdbuf[rdLeft++]^48;
            x=x*10+(neg?-d:d);
        }
        return true;
    }
    // Parse an optionally negative decimal number with an optional '.' fractional part (no exponent).
    template<typename T,std::enable_if_t<std::is_floating_point<T>::value,int> =0>inline bool _read(T& x){
        if(!skip())return false;
        if(rdLeft+20>=rdRight)reload();
        bool neg=false;
        if(rdbuf[rdLeft]=='-'){
            neg=true;
            rdLeft++;
        }
        x=0;
        while(digitAt(rdLeft))x=x*10+(rdbuf[rdLeft++]^48);
        if(rdLeft<rdRight and rdbuf[rdLeft]=='.'){
            rdLeft++;
            T base=.1;
            while(digitAt(rdLeft)){
                x+=base*(rdbuf[rdLeft++]^48);
                base*=.1;
            }
        }
        // Apply the sign on every path, including numbers without a fractional part ("-5").
        if(neg)x=-x;
        return true;
    }
    // Read the next non-whitespace byte.
    inline bool _read(char& x){
        if(!skip())return false;
        if(rdLeft+1>=rdRight)reload();
        x=rdbuf[rdLeft++];
        return true;
    }
    // Append the next whitespace-delimited token to x (x is not cleared first).
    inline bool _read(std::string& x){
        if(!skip())return false;
        for(;;){
            int pos=rdLeft;
            while(pos<rdRight and rdbuf[pos]>' ')pos++;
            x.append(rdbuf+rdLeft,pos-rdLeft);
            if(rdLeft==pos)break;
            rdLeft=pos;
            if(rdLeft==rdRight)reload();
            else break;
        }
        return true;
    }
    // Fill every element of an already-sized vector.
    template<typename T>inline bool _read(std::vector<T>& v){
        for(auto& x:v){
            if(!_read(x))return false;
        }
        return true;
    }

    char wtbuf[L],tmp[50];
    int wtRight=0;
    inline void flush(){
        std::fwrite(wtbuf,1,wtRight,stdout);
        wtRight=0;
    }
    inline void _write(const char& x){
        if(wtRight>L-32)flush();
        wtbuf[wtRight++]=x;
    }
    inline void _write(const std::string& x){
        for(auto& c:x)_write(c);
    }
    // Emit the decimal digits of a non-negative value (at least one digit, so 0 prints "0").
    template<typename U>inline void _write_digits(U u){
        int pos=0;
        do{
            tmp[pos++]=char((u%10)|48);
            u/=10;
        }while(u!=0);
        // Make room for the whole number, including 39-digit 128-bit values.
        if(wtRight>L-pos)flush();
        for(int i=0;i<pos;i++)wtbuf[wtRight+i]=tmp[pos-1-i];
        wtRight+=pos;
    }
    template<typename T,std::enable_if_t<std::is_integral<T>::value,int> =0>inline void _write(T x){
        if constexpr(std::is_signed<T>::value){
            using U=std::make_unsigned_t<T>;
            if(x<0){
                _write('-');
                // Negate in the unsigned domain so numeric_limits<T>::min() works for every width.
                _write_digits(static_cast<U>(U(0)-static_cast<U>(x)));
                return;
            }
            _write_digits(static_cast<U>(x));
        }else{
            _write_digits(x);
        }
    }
    // Space-separated elements, no trailing separator.
    template<typename T>inline void _write(const std::vector<T>& v){
        for(std::size_t i=0;i<v.size();i++){
            if(i)_write(' ');
            _write(v[i]);
        }
    }
public:
    FastIO(){}
    ~FastIO(){flush();}
    inline void read(){}
    // Read each argument in order. Asserts (in debug builds) that input did not run out.
    template <typename Head, typename... Tail>inline void read(Head& head,Tail&... tail){
        // The read must stay outside assert(): with NDEBUG the assert expression is never evaluated.
        const bool ok=_read(head);
        assert(ok);
        (void)ok;
        read(tail...);
    }
    // Write arguments separated by spaces; ln=true appends a trailing '\n'.
    template<bool ln=true,bool space=false>inline void write(){if(ln)_write('\n');}
    template <bool ln=true,bool space=false,typename Head, typename... Tail>inline void write(const Head& head,const Tail&... tail){
        if(space)_write(' ');
        _write(head);
        write<ln,true>(tail...);
    }
};

/**
 * @brief Fast IO
 */
#line 3 "sol.cpp"

#line 2 "library/Graph/hld.hpp"

// Heavy-light decomposition of a tree on vertices 0..n-1.
// Call add_edge() for every edge, then run(root); lca/get/jump read the arrays filled by run().
struct HLD{
    using P=pair<int,int>;
    // g    : adjacency lists; for a vertex with children, dfs() moves its largest child into g[v][0]
    // sz   : subtree sizes
    // in   : preorder index (the heavy child is visited first, so each heavy chain is contiguous)
    // out  : one past the last preorder index of v's subtree, i.e. the subtree is [in[v], out[v])
    // rev  : inverse of in (rev[in[v]] == v)
    // hs   : head (topmost vertex) of the heavy chain containing v
    // par  : parent, -1 for the root
    // dist : depth in edges from the root
    vector<vector<int>> g; vector<int> sz,in,out,rev,hs,par,dist;
    // First pass: fills par, sz and dist, and reorders g[v] so the largest child is first.
    // NOTE(review): recursive; the depth equals the tree height, so very deep trees may need a large stack.
    void dfs(int v,int p){
        par[v]=p; sz[v]=1;
        if(p!=-1)dist[v]=dist[p]+1;
        // Keep the parent out of slot 0 so that slot can hold the heavy child.
        if(!g[v].empty() and g[v][0]==p)swap(g[v][0],g[v].back());
        for(auto& to:g[v])if(to!=p){
           dfs(to,v); sz[v]+=sz[to];
           // Swapping through the reference keeps g[v][0] as the largest child seen so far.
           if(sz[g[v][0]]<sz[to])swap(g[v][0],to);
        }
    }
    // Second pass: assigns in/out/rev and the chain heads hs. k is the next preorder index.
    void dfs2(int v,int p,int& k){
        in[v]=k++; rev[in[v]]=v;
        for(auto& to:g[v])if(to!=p){
            // The heavy child continues v's chain; every other child starts a new chain.
            hs[to]=(g[v][0]==to?hs[v]:to);
            dfs2(to,v,k);
        }
        out[v]=k;
    }
    HLD(int _n):g(_n),sz(_n),in(_n),out(_n),rev(_n),hs(_n),par(_n),dist(_n){}
    // Adds the undirected edge u-v.
    void add_edge(int u,int v){
        g[u].emplace_back(v); g[v].emplace_back(u);
    }
    // Builds the decomposition rooted at rt.
    void run(int rt=0){dfs(rt,-1); hs[rt]=rt; int k=0; dfs2(rt,-1,k);}
    // Lowest common ancestor. Repeatedly lifts the vertex with the larger preorder index
    // to the parent of its chain head until both share a chain; then the one with the
    // smaller preorder index is the answer.
    int lca(int u,int v){
        for(;;v=par[hs[v]]){
            if(in[u]>in[v])swap(u,v);
            if(hs[u]==hs[v])return u;
        }
    }
    // Splits the vertical path from u up to its ancestor p into half-open ranges [l, r)
    // of preorder indices, listed from u's end upward. With es=1 the vertex p itself is excluded.
    vector<P> get(int u,int p,bool es=0){
        assert(in[p]<=in[u] and out[u]<=out[p]);  // p must be u or an ancestor of u
        vector<P> res;
        while(hs[u]!=hs[p]){
            res.push_back({in[hs[u]],in[u]+1});
            u=par[hs[u]];
        }
        res.push_back({in[p]+es,in[u]+1});
        return res;
    }
    // Returns the k-th vertex on the path u -> v (k=0 gives u).
    // Returns -1 if k<0 or k exceeds the path length.
    int jump(int u,int v,int k){
        if(k<0)return -1;
        int g=lca(u,v);  // NOTE: this local shadows the member g (adjacency lists) in this function
        int d0=dist[u]+dist[v]-dist[g]*2;  // path length in edges
        if(d0<k)return -1;
        int st=u;
        // If the target lies on the v side of the LCA, walk up from v by the remaining distance.
        if(dist[u]-dist[g]<k)st=v,k=d0-k;
        for(;;){
            int to=hs[st];
            // Inside one heavy chain, ancestors sit at consecutive smaller preorder indices.
            if(in[st]-k>=in[to])return rev[in[st]-k];
            k-=in[st]-in[to]+1; st=par[to];
        }
    }
};

/**
 * @brief Heavy Light Decomposition
 */
#line 2 "library/DataStructure/lazysegtree.hpp"

// Lazy-propagation segment tree over values of type M, with lazy tags of type N.
//   f(a,b)    : combine two adjacent segments (m1() is its identity)
//   g(a,x)    : apply tag x to a segment value
//   h(old,x)  : compose a newer tag x onto an older pending tag (n1() is the empty tag)
// Indices are 0-based; ranges are half-open [L, R).
template<typename M,typename N,M (*f)(M,M),M (*g)(M,N),N (*h)(N,N),M (*m1)(),N (*n1)()>
class LazySegmentTree{
    int leaves,depth;      // leaf count (a power of two) and log2(leaves)
    std::vector<M> node;   // node[1] is the root; the leaves are node[leaves .. 2*leaves)
    std::vector<N> tag;    // pending tag of each internal node [1, leaves)

    // Recompute internal node k from its children.
    void pull(int k){node[k]=f(node[2*k],node[2*k+1]);}
    // Apply x to node k; internal nodes also remember it for their children.
    void apply(int k,N x){
        node[k]=g(node[k],x);
        if(k<leaves)tag[k]=h(tag[k],x);
    }
    // Hand node k's pending tag to both children and clear it.
    void push(int k){
        const N pending=tag[k];
        apply(2*k,pending);
        apply(2*k+1,pending);
        tag[k]=n1();
    }
    // True when leaf position p does not start a 2^i-aligned block.
    static bool misaligned(int p,int i){return (p&((1<<i)-1))!=0;}
    // Push tags down along the boundary paths of the leaf range [l, r), top to bottom.
    void push_boundaries(int l,int r){
        for(int i=depth;i>=1;--i){
            if(misaligned(l,i))push(l>>i);
            if(misaligned(r,i))push((r-1)>>i);
        }
    }
public:
    LazySegmentTree(int n=0):LazySegmentTree(std::vector<M>(n,m1())){}
    LazySegmentTree(const std::vector<M>& a):leaves(1),depth(0){
        while(leaves<(int)a.size()){
            leaves*=2;
            ++depth;
        }
        node.assign(2*leaves,m1());
        tag.assign(leaves,n1());
        for(int i=0;i<(int)a.size();++i)node[leaves+i]=a[i];
        for(int k=leaves-1;k>=1;--k)pull(k);
    }
    // Overwrite position k with x.
    void set(int k,M x){
        const int leaf=k+leaves;
        for(int i=depth;i>=1;--i)push(leaf>>i);
        node[leaf]=x;
        for(int i=1;i<=depth;++i)pull(leaf>>i);
    }
    // Apply tag x to every position in [L, R).
    void update(int L,int R,N x){
        if(L>=R)return;
        L+=leaves;
        R+=leaves;
        push_boundaries(L,R);
        for(int l=L,r=R;l<r;l>>=1,r>>=1){
            if(l&1)apply(l++,x);
            if(r&1)apply(--r,x);
        }
        for(int i=1;i<=depth;++i){
            if(misaligned(L,i))pull(L>>i);
            if(misaligned(R,i))pull((R-1)>>i);
        }
    }
    // Combined value over [L, R); m1() for an empty range.
    M query(int L,int R){
        if(L>=R)return m1();
        L+=leaves;
        R+=leaves;
        push_boundaries(L,R);
        M left=m1(),right=m1();
        for(;L<R;L>>=1,R>>=1){
            if(L&1)left=f(left,node[L++]);
            if(R&1)right=f(node[--R],right);
        }
        return f(left,right);
    }
};

/**
 * @brief Lazy Segment Tree
 */
#line 2 "library/DataStructure/bit.hpp"

// Fenwick tree (binary indexed tree) over positions [0, n) holding values of type T.
// `all` keeps the running total of every value passed to add().
template<typename T>struct BIT{
    int n;
    T all=0;
    std::vector<T> val;  // 1-based internal array with a few spare slots past n
    BIT(int _n=0):n(_n),val(_n+10){}
    // Reset every value (and the running total) to zero.
    void clear(){
        val.assign(n+10,0);
        all=T();
    }
    // Add x at position i.
    void add(int i,T x){
        int k=i+1;
        while(k<=n){
            val[k]+=x;
            k+=k&-k;
        }
        all+=x;
    }
    // Sum of positions [0, i).
    T sum(int i){
        T acc=0;
        while(i){
            acc+=val[i];
            i-=i&-i;
        }
        return acc;
    }
    // Sum of positions [L, R).
    T sum(int L,int R){return sum(R)-sum(L);}
    // Number of leading positions whose prefix sum stays below x (meaningful for non-negative values).
    int lower_bound(T x){
        int pos=0;
        int step=1;
        while(step*2<=n)step*=2;
        for(;step>0;step/=2){
            const int nxt=pos+step;
            if(nxt<=n and val[nxt]<x){
                pos=nxt;
                x-=val[nxt];
            }
        }
        return pos;
    }
};

/**
 * @brief Binary Indexed Tree
 */
#line 7 "sol.cpp"

struct M{
    ll bsum,absum,a2bsum;
    M():bsum(0),absum(0),a2bsum(0){}
    M(ll _b):bsum(_b),absum(0),a2bsum(0){}
};

M f(M a,M b){
    M ret;
    ret.bsum=a.bsum+b.bsum;
    ret.absum=a.absum+b.absum;
    ret.a2bsum=a.a2bsum+b.a2bsum;
    return ret;
}
M g(M a,ll v){
    M ret=a;
    ret.absum+=a.bsum*v;
    ret.a2bsum+=v*2*a.absum+v*v*a.bsum;
    return ret;
}
ll h(ll a,ll b){return a+b;}
M m0(){return M();}
ll n0(){return 0;}

FastIO io;
int main(){
    int n,q;
    io.read(n,q);

    HLD hld(n);
    rep(_,0,n-1){
        int u,v;
        io.read(u,v);
        u--; v--;
        hld.add_edge(u,v);
    }
    hld.run();

    BIT<ll> cnt(n),sum(n);
    vector<M> tmp(n);
    tmp[0]=M(1);
    rep(i,1,n){
        tmp[hld.in[i]]=M((i+1)-(hld.par[i]+1));
    }
    LazySegmentTree<M,ll,f,g,h,m0,n0> seg(tmp);
    
    vector<int> S(n);
    auto change=[&](int x)->void{
        if(S[x]==0){
            S[x]=1;
            cnt.add(hld.in[x],1);
            sum.add(hld.in[x],x+1);
            auto ps=hld.get(x,0);
            for(auto& [L,R]:ps)seg.update(L,R,1);
        }
        else{
            S[x]=0;
            cnt.add(hld.in[x],-1);
            sum.add(hld.in[x],-(x+1));
            auto ps=hld.get(x,0);
            for(auto& [L,R]:ps)seg.update(L,R,-1);
        }
    };
    auto fixroot=[&](int x)->ll{
        ll ret=seg.query(hld.in[x],hld.out[x]).a2bsum;
        ret-=sum.sum(hld.in[x],hld.out[x]);
        if(x!=0){
            ll sz=cnt.sum(hld.in[x],hld.out[x]);
            ret+=(hld.par[x]+1)*sz*sz;
        }
        assert(ret%2==0);
        return ret/2;
    };
    auto same=[&](int x)->ll{
        ll ret=fixroot(0);
        auto ps=hld.get(x,0,1);
        M done=M();
        for(auto& [L,R]:ps){
            done=f(done,seg.query(L,R));
        }
        ret+=done.absum*cnt.all;
        ret-=done.a2bsum;
        return ret;
    };
    while(q--){
        int u,r,v;
        io.read(u,r,v);
        u--; r--; v--;
        change(u);

        ll ret;
        if(r==v)ret=same(r);
        else if(hld.in[r]<hld.in[v] or hld.out[v]<=hld.in[r])ret=fixroot(v);
        else{
            int x=hld.jump(v,r,1);
            ret=same(v)-fixroot(x);
            ll sz=cnt.sum(hld.in[x],hld.out[x]);
            ret-=ll(v+1)*sz*(cnt.all-sz);
        }
        io.write(ret);
    }
    return 0;
}
0