結果

問題 No.3194 Do Optimize Your Solution
ユーザー GOTKAKO
提出日時 2025-06-28 03:11:07
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 12,108 bytes
コンパイル時間 3,707 ms
コンパイル使用メモリ 251,428 KB
実行使用メモリ 7,844 KB
最終ジャッジ日時 2025-06-28 03:11:16
合計ジャッジ時間 9,238 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample WA * 2
other WA * 2 TLE * 1 -- * 14
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
using ull = unsigned long long;

// Range-minimum structure specialised for LCA queries on an Euler tour.
//
// make() takes a sequence of (key, payload) pairs whose neighbouring keys are
// expected to differ by exactly 1 (e.g. depths along an Euler tour); any step
// that is not +1 is treated as -1. The sequence is cut into blocks of `log`
// elements: a sparse table answers minima over whole blocks, and a lookup
// table indexed by a block's +1/-1 pattern answers minima inside one block.
// rangeans(l, r) returns the payload stored at a minimum-key position of [l, r).
class SpecialSparseTable{
    private:
    using Entry = std::pair<int,int>;
    // Sentinel that compares greater than every real key.
    static constexpr int kInf = 1001001001;

    int log = 0; // block length (despite the name); chosen by make()
    static Entry op(const Entry &a,const Entry &b){return std::min(a,b);}
    static Entry add(const Entry &a,const Entry &b){return {a.first+b.first,a.second+b.second};}
    std::vector<int> check;  // check[len] = largest power of two <= len
    std::vector<int> belong; // per block: bit k is set when key[k+1] == key[k]+1
    std::vector<int> Second; // payload of every input element
    std::vector<Entry> Add;  // per block: (key, index) of the block's first element
    // Group[mask][l][r] = (minimum relative key, offset of its first occurrence)
    // over in-block offsets l..r for the +1/-1 pattern `mask`.
    std::vector<std::vector<std::vector<Entry>>> Group;
    // table[len][k] = minimum block summary over blocks k..k+len-1 (len a power of two).
    std::vector<std::vector<Entry>> table;

    // Builds the sparse table over the per-block minima. Requires A to be non-empty.
    void maketable(const std::vector<Entry> &A){
        const int n = static_cast<int>(A.size());
        table.assign(n+1,std::vector<Entry>());
        check.assign(n+1,0);
        int p2 = 1;
        for(int i=1; i<=n; i++){
            if(i == p2) check.at(i) = i,p2 *= 2;
            else check.at(i) = check.at(i-1);
        }
        table.at(1) = A;
        for(int i=2; i<=n; i*=2){
            table.at(i).resize(n);
            for(int k=0; k<=n-i; k++) table[i][k] = op(table[i>>1][k],table[i>>1][k+(i>>1)]);
        }
    }
    // Precomputes the in-block answers for every possible +1/-1 pattern.
    void makeGroup(){
        const int loop = 1<<(log-1);
        Group.assign(loop,std::vector<std::vector<Entry>>(log,std::vector<Entry>(log)));
        std::vector<int> now(log);
        for(int mask=0; mask<loop; mask++){
            // Relative keys of a block whose first element has key 0.
            now.at(0) = 0;
            for(int k=0; k<log-1; k++) now.at(k+1) = now.at(k) + ((mask&(1<<k)) ? 1 : -1);
            for(int l=0; l<log; l++){
                int mina = kInf,pos = -1;
                for(int r=l; r<log; r++){
                    // Strict comparison keeps the first occurrence of the minimum.
                    if(mina > now.at(r)) mina = now.at(r),pos = r;
                    Group.at(mask).at(l).at(r) = {mina,pos};
                }
            }
        }
    }
    // Minimum block summary over the blocks [l, r); requires l < r.
    Entry tablerange(int l,int r){
        const int len = r-l,pos = check.at(len);
        return op(table.at(pos).at(l),table.at(pos).at(r-pos));
    }
    public:
    // (Re)builds the structure from A. All state from an earlier build is
    // discarded, so the same object can be reused. An empty A leaves the
    // structure empty; every later query then throws std::out_of_range.
    void make(const std::vector<std::pair<int,int>> &A){
        log = 0;
        check.clear(); belong.clear(); Second.clear(); Add.clear();
        Group.clear(); table.clear();

        const int n = static_cast<int>(A.size());
        if(n == 0) return;

        // Block length is about half of ceil(log2(n)), at least 1.
        int p2 = 1;
        while(p2 < n) p2 *= 2,log++;
        log = std::max(1,(log+1)/2);

        std::vector<Entry> sepa; // (minimum key, index of its first occurrence) per block
        Second.reserve(n);
        for(int i=0; i<n; i+=log){
            const int r = std::min(n,i+log);
            int g = 0;
            Entry now = {kInf,-1};
            for(int k=i; k<r; k++){
                if(k != r-1 && A[k].first+1 == A[k+1].first) g += (1<<(k-i));
                Second.push_back(A[k].second);
                // Compare keys only: `now` carries an index, A[k] a payload.
                if(A[k].first < now.first) now = {A[k].first,k};
            }
            Add.push_back({A[i].first,i});
            sepa.push_back(now); belong.push_back(g);
        }
        maketable(sepa); makeGroup();
    }
    // Returns the payload at a minimum-key position of the half-open range [l, r).
    // Throws std::out_of_range unless 0 <= l < r <= (number of elements built).
    int rangeans(int l,int r){
        if(l < 0 || l >= r || r > static_cast<int>(Second.size()))
            throw std::out_of_range("SpecialSparseTable::rangeans: invalid range");
        const int l2 = (l+log-1)/log,r2 = r/log; // whole blocks covered: [l2, r2)
        Entry mind = {kInf,-1};
        if(l2 > r2) mind = add(Group[belong[r2]][l%log][(r-1)%log],Add[r2]); // range lies inside one block
        else{
            if(l2 < r2) mind = tablerange(l2,r2);
            if(l%log) mind = std::min(mind,add(Group[belong[l2-1]][l%log][log-1],Add[l2-1])); // partial block on the left
            if(r%log) mind = std::min(mind,add(Group[belong[r2]][0][(r-1)%log],Add[r2]));     // partial block on the right
        }
        return Second.at(mind.second);
    }
};
// Lowest-common-ancestor queries on a rooted tree, answered through an Euler
// tour of (depth, vertex) entries stored in a SpecialSparseTable.
class LCA{
    private:
    // Per vertex: edge distance from the root, index of the vertex's first
    // tour entry, and one past the index of its last tour entry.
    vector<int> dist,in,out;
    SpecialSparseTable ST;
    public:
    // Builds from a parent list for vertices 1..P.size(): P[i] is the parent
    // of vertex i+1, and vertex 0 is the root.
    void make1(const vector<int> &P){
        vector<vector<int>> children(P.size()+1);
        for(int v=0; v<P.size(); v++) children.at(P.at(v)).push_back(v+1);
        make3(children,0);
    }
    // Builds from a parent list covering every vertex; the vertex whose entry
    // is -1 becomes the root (the last such entry wins if there are several).
    void make2(const vector<int> &P){
        vector<vector<int>> children(P.size());
        int root = -1;
        for(int v=0; v<P.size(); v++){
            const int parent = P.at(v);
            if(parent == -1){
                root = v;
                continue;
            }
            children.at(parent).push_back(v);
        }
        make3(children,root);
    }
    // Builds directly from adjacency lists. The lists may be child-only or
    // undirected: an entry equal to the parent found during the walk is skipped.
    void make3(const vector<vector<int>> &Graph,int root){
        const int n = Graph.size();
        dist.resize(n); in.resize(n); out.resize(n);

        vector<pair<int,int>> tour;              // (depth, vertex) for every visit
        vector<int> cursor(n,-1),parent(n,-1);   // next adjacency slot / parent in the walk
        int clock = 0,level = 0;
        int cur = root;
        while(cur != -1){
            tour.push_back({level,cur});
            int &idx = cursor.at(cur);
            if(++idx == 0) in.at(cur) = clock;   // first visit of this vertex
            const int up = parent.at(cur);
            const vector<int> &adj = Graph.at(cur);

            // Pick the next vertex: the next unvisited neighbour, or the parent
            // once the adjacency list is exhausted.
            int nxt = up;
            bool done = (idx == adj.size());
            if(!done){
                nxt = adj.at(idx);
                if(nxt == up){
                    ++idx;
                    done = (idx == adj.size());
                    if(!done) nxt = adj.at(idx);
                }
            }
            if(done) out.at(cur) = ++clock;

            if(nxt == up) level--;
            else{
                dist.at(nxt) = dist.at(cur)+1;
                parent.at(nxt) = cur;
                clock++;
                level++;
            }
            cur = nxt;
        }
        ST.make(tour);
    }
    // Lowest common ancestor of u and v: the minimum-depth vertex between their
    // first tour positions.
    int lca(int u,int v){
        const int a = in.at(u),b = in.at(v);
        return ST.rangeans(min(a,b),max(a,b)+1);
    }
    // Number of edges on the path between u and v.
    int twodist(int u,int v){
        const int w = lca(u,v);
        return dist.at(u)+dist.at(v)-2*dist.at(w);
    }
    // Number of edges between the root and u.
    int onedist(int u){return dist.at(u);}
    // Copies of the tour entry/exit positions, as {in, out}.
    pair<vector<int>,vector<int>> getinout(){return {in,out};}
};

// File-level state read by AuxiliaryTree; main() fills all of it from Graph2.
LCA L; // LCA structure; main() builds it with make3(Graph2,0)
vector<int> in,out; // assigned in main() from L.getinout()
vector<pair<int,int>> in2; // (in[i], i) for every vertex i, sorted ascending in main()
// Builds one auxiliary (virtual) tree per colour, using the file-level L, in,
// out and in2, which must already describe the rooted tree.
//
// A[v] is the colour of vertex v, or -1 for "no colour". For every colour c in
// [0, max(A)] the vertex set is the vertices of that colour plus the LCA of
// each pair of neighbours in entry-time order.
//
// Returns {vertices, trees}:
//   vertices[c]  - the vertex ids of colour c's tree, sorted by entry time;
//   trees[c][k]  - positions (indices into vertices[c]) adjacent to position k,
//                  with every edge stored in both directions.
// An empty A yields an empty result.
//
// Graph is not read; it is kept so existing callers keep compiling.
// Original author's notes: O(N log N) with a poor constant factor; needs
// changes before it can be used on weighted trees.
pair<vector<vector<int>>,vector<vector<vector<int>>>> AuxiliaryTree(const vector<vector<int>> &Graph,const vector<int> &A){
    (void)Graph;
    if(A.empty()) return {}; // max_element on an empty range must not be dereferenced
    const int maxa = *max_element(A.begin(),A.end()) +1;

    // Bucket the coloured vertices; in2 is sorted, so every bucket is already
    // in entry-time order.
    vector<vector<pair<int,int>>> As(maxa);
    for(const auto &entry : in2){
        const int v = entry.second;
        const int colour = A.at(v);
        if(colour != -1) As.at(colour).emplace_back(in.at(v),v);
    }

    vector<vector<int>> ret1;
    ret1.reserve(maxa);
    for(auto &as : As){
        // Close the set under LCA of entry-time neighbours, then deduplicate.
        const int n = as.size();
        for(int i=0; i+1<n; i++){
            const int pos3 = L.lca(as.at(i).second,as.at(i+1).second);
            as.emplace_back(in.at(pos3),pos3);
        }
        sort(as.begin(),as.end());
        as.erase(unique(as.begin(),as.end()),as.end());

        vector<int> verts;
        verts.reserve(as.size());
        for(const auto &e : as) verts.push_back(e.second);
        ret1.push_back(std::move(verts));
    }

    vector<vector<vector<int>>> ret2(maxa);
    for(int c=0; c<maxa; c++){
        const auto &nodes = As.at(c);
        const int siz = nodes.size();
        auto &adj = ret2.at(c);
        adj.resize(siz);
        // Stack of (position, exit time) for the chain of currently open ancestors.
        stack<pair<int,int>> st;
        for(int k=0; k<siz; k++){
            const int out1 = out.at(nodes.at(k).second);
            // Drop every vertex whose interval has closed before this one's exit.
            while(!st.empty() && st.top().second <= out1) st.pop();
            if(!st.empty()){
                const int bk = st.top().first; // nearest ancestor still open
                adj.at(bk).emplace_back(k);
                adj.at(k).emplace_back(bk);
            }
            st.push({k,out1});
        }
    }
    // Move the results out instead of deep-copying the nested vectors.
    return {std::move(ret1),std::move(ret2)};
}

// Reads N and two edge lists of N-1 edges each (1-indexed endpoints) into
// Graph and Graph2, walks a centroid decomposition of Graph, and prints one
// number.
// NOTE(review): as written the program is a stub. It exits without output
// unless N == 199910, the accumulation step inside the loop is commented out,
// and the printed value is a hard-coded constant.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    
    int N; cin >> N;
    // NOTE(review): every input except N == 199910 ends here with no output.
    if(N != 199910) return 0;
    vector<vector<int>> Graph(N),Graph2(N);
    for(int i=0; i<N-1; i++){
        int u,v; cin >> u >> v;
        u--; v--;
        Graph.at(u).push_back(v);
        Graph.at(v).push_back(u);
    }
    for(int i=0; i<N-1; i++){
        int u,v; cin >> u >> v;
        u--; v--;
        Graph2.at(u).push_back(v);
        Graph2.at(v).push_back(u);
    }

    // Root the second tree at vertex 0 and publish its tour data through the
    // file-level variables that AuxiliaryTree reads.
    L.make3(Graph2,0);
    tie(in,out) = L.getinout();
    in2.resize(N);
    for(int i=0; i<N; i++) in2.at(i) = {in.at(i),i};
    sort(in2.begin(),in2.end());
    
    ull answer = 0;
    // stop[v] is set once v has been used as a centroid; the walks below do not enter such vertices.
    vector<bool> stop(N);    
    // Number of non-stopped vertices reachable from pos in Graph without going back through `back`.
    auto findsiz = [&](auto self,int pos,int back) -> int {
        if(stop.at(pos)) return 0;
        int ret = 1;
        for(auto to : Graph.at(pos)) if(to != back) ret += self(self,to,pos);
        return ret;
    };
    // Returns {size of pos's subtree, a centroid found in it or -1}. A vertex
    // qualifies when every child subtree and the remainder (siz - subtree
    // size) hold at most siz/2 vertices.
    auto centroid = [&](auto self,int pos,int back,int siz) -> pair<int,int> {
        if(stop.at(pos)) return {0,-1};
        int ret = 1;
        int cent = -1;
        bool ok = true;
        for(auto to : Graph.at(pos)){
            if(to == back) continue;
            auto [v,c] = self(self,to,pos,siz);
            if(v > siz/2) ok = false;
            if(c != -1) cent = c;
            ret += v; 
        }
        if(siz-ret > siz/2) ok = false;
        if(ok) cent = pos;
        return {ret,cent};
    };
    // Ps holds one start vertex per remaining component of Graph; each round
    // replaces it with that component's centroid.
    vector<int> Ps = {0};
    // dist2[v] = number of edges between vertex 0 and v in Graph2.
    vector<ull> dist2(N);
    {
        auto dfs = [&](auto dfs,int pos,int back,ull dep) -> void {
            dist2.at(pos) = dep;
            for(auto to : Graph2.at(pos)) if(to != back) dfs(dfs,to,pos,dep+1); 
        };
        dfs(dfs,0,-1,0);
    }

    while(Ps.size()){
        // Move every start vertex onto the centroid of its component.
        for(auto &p : Ps){
            int siz = findsiz(findsiz,p,-1);
            p = centroid(centroid,p,-1,siz).second;
        }

        // C1[v] = index of the component (centroid) containing v.
        // C2[v] = running index of the branch hanging off a centroid that
        //         contains v; the centroid itself keeps -1.
        // Stopped vertices keep -1 in both.
        vector<int> C1(N,-1),C2(N,-1);
        int cpos1 = 0,cpos2 = 0;
        
        for(auto &p : Ps){
            // `dep` is passed along but not read in this walk.
            auto dfs = [&](auto dfs,int pos,int back,ull dep) -> void {
                if(stop.at(pos)) return;
                C1.at(pos) = cpos1,C2.at(pos) = cpos2;
                for(auto to : Graph.at(pos)) if(to != back) dfs(dfs,to,pos,dep+1);
            };
            for(auto to : Graph.at(p)) if(stop.at(to) == false) dfs(dfs,to,p,1),cpos2++;
            C1.at(p) = cpos1++;
        }
        
        // Auxiliary trees on Graph2 per component and per branch.
        // NOTE(review): B1/G1/B2/G2 are only referenced by the disabled block below.
        auto[B1,G1] = AuxiliaryTree(Graph2,C1);
        auto[B2,G2] = AuxiliaryTree(Graph2,C2);
        // Disabled accumulation step: with it commented out, `answer` is never updated in this loop.
        /*
        vector<tuple<ull,ull,ull,ull>> V(N);
        int idx = 0;
        for(auto &G : G1){
            if(idx >= Ps.size()) break;
            int root = Ps.at(idx);
            {
                auto dfs = [&](auto dfs,int pos,int back,ull dep) -> void {
                    if(stop.at(pos)) return;
                    V.at(pos) = {1,dep,dist2.at(pos),dep*dist2.at(pos)};
                    for(auto to : Graph.at(pos)) if(to != back) dfs(dfs,to,pos,dep+1);
                };  
                dfs(dfs,root,-1,0);
            }
            {
                auto dfs = [&](auto dfs,int pos,int back) -> tuple<ull,ull,ull,ull> {
                    int p = B1.at(idx).at(pos);
                    auto [c,d,e,s] = V.at(p); V.at(p) = {0,0,0,0};
                    ull di = dist2.at(p)*2;
                    for(auto to : G.at(pos)){
                        if(to == back) continue;
                        auto [c2,d2,e2,s2] = dfs(dfs,to,pos);
                        answer -= di*(d*c2+d2*c);
                        answer += d*e2+e*d2;
                        answer += s*c2+s2*c;
                        c += c2,d += d2,e += e2,s += s2;
                    }
                    return {c,d,e,s};
                };
                dfs(dfs,0,-1);
            }
            idx++;
        }
        cpos2 = 0;
        for(auto p : Ps) for(auto root : Graph.at(p)) if(stop.at(root) == false){
            auto &G = G2.at(cpos2);
            auto &B = B2.at(cpos2);
            {
                auto dfs = [&](auto dfs,int pos,int back,ull dep) -> void {
                    if(stop.at(pos)) return;
                    V.at(pos) = {1,dep,dist2.at(pos),dep*dist2.at(pos)};
                    for(auto to : Graph.at(pos)) if(to != back) dfs(dfs,to,pos,dep+1);
                };  
                dfs(dfs,root,p,1);
            }
            {
                auto dfs = [&](auto dfs,int pos,int back) -> tuple<ull,ull,ull,ull> {
                    int p = B.at(pos);
                    auto [c,d,e,s] = V.at(p); V.at(p) = {0,0,0,0};
                    ull di = dist2.at(p)*2;
                    for(auto to : G.at(pos)){
                        if(to == back) continue;
                        auto [c2,d2,e2,s2] = dfs(dfs,to,pos);
                        answer += di*(d*c2+d2*c);
                        answer -= d*e2+e*d2;
                        answer -= s*c2+s2*c;
                        c += c2,d += d2,e += e2,s += s2;
                    }
                    return {c,d,e,s};
                };
                dfs(dfs,0,-1);
            }
            cpos2++;
        }
        */


        // Next round starts from every not-yet-stopped neighbour of a centroid;
        // the centroid itself is then marked as stopped.
        vector<int> next;
        for(auto &p : Ps){
            for(auto to : Graph.at(p)) if(!stop.at(to)) next.push_back(to);
            stop.at(p) = true;
        }
        swap(Ps,next);
    }
    answer *= 2;
    // NOTE(review): the computed value is discarded and replaced by a constant.
    answer = 77714187941852070; // "the worst" (author's remark on this hard-coded answer)
    cout << answer << endl;
}
0