結果

問題 No.2676 A Tourist
ユーザー 👑 nu50218nu50218
提出日時 2024-02-15 22:34:36
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
TLE  
実行時間 -
コード長 6,371 bytes
コンパイル時間 1,912 ms
コンパイル使用メモリ 120,848 KB
実行使用メモリ 129,592 KB
最終ジャッジ日時 2024-03-13 21:31:22
合計ジャッジ時間 12,558 ms
ジャッジサーバーID
(参考情報)
judge13 / judge11
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
6,676 KB
testcase_01 AC 2,553 ms
40,788 KB
testcase_02 TLE -
testcase_03 -- -
testcase_04 -- -
testcase_05 -- -
testcase_06 -- -
testcase_07 -- -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
testcase_29 -- -
testcase_30 -- -
testcase_31 -- -
testcase_32 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

// 当初のwriter解

#include <cassert>
#include <cmath>
#include <iostream>
#include <stack>
#include <vector>
using namespace std;




template<class T>
class SparseTable{
    vector<vector<T>> table; //[i][x] := 左端が i の、連続する 2^x 個の区間の merge

    T merge(T a , T b){
        if(a < b)return a;
        else return b;
    }
    int N;
    vector<T> A;

    public:

    SparseTable(){}
    SparseTable(vector<T>a_){
        A = a_;
        N = A.size();
        int k = 0;
        int s = 1;
        while(s < N){
            s *= 2;
            k += 1;
        }
        
        table.resize(N , vector<T>(k+1));
        for(int i = 0 ; i < N ; i++)table[i][0] = A[i];

        for(int x = 1 ; x <= k ; x++){
            for(int i = 0 ; i < N ; i++){
                int mid = i + (1<<(x-1));
                if(mid >= N)table[i][x] = table[i][x-1];
                else table[i][x] = merge(table[i][x-1] , table[mid][x-1]);
            }
        }
    }

    // [l,r)
    T query(int l , int r){
        assert(l < r);
        int range = r-l;
        int wid = 1;
        int k = 0;
        while(wid <= range){
            k++;
            wid *= 2;
        }
        wid/=2;
        k--;
        int l2 = r-wid;
        if(l2 >= N)return table[l][k];

        return merge(table[l][k] , table[l2][k]);
    }
};


class EulerTour{
    vector<vector<int> >G;// 使い捨て
    public :
    vector<int> tour;// ET
    vector<int> Left;// [x] := tour において x が初めて登場する index

    EulerTour(){}
    EulerTour(vector<vector<int> > G_ , int root = 1):G(G_){
        stack<int> S;// (now , parent)       
        S.emplace(root);
        Left.resize(G.size());
        vector<int> parent(G.size() , -1);
        
        while(!S.empty()){
            int now = S.top();
            S.pop();
            if(now == -1)break;
            tour.push_back(now);

            if(G[now].size() != 0 && G[now].back() == parent[now])G[now].pop_back();

            if(G[now].size() == 0)S.push(parent[now]);
            else{
                parent[G[now].back()] = now;
                S.push(G[now].back());
                G[now].pop_back();
            }
        }

        for(int i = int(tour.size()) - 1 ; i >= 0 ;i--){
            Left[tour[i]] = i;
        }
    }
};




int N;
int root = 1;
vector<long long> a;          // [u] := 頂点 u の価値
vector<long long> a_adj;      // [u] := 頂点 u に隣接する頂点の価値の和(自身を除く)
vector<vector<int> > G;       // 隣接リスト
vector<long long> depth;      // [u] := 頂点 u の深さ
vector<long long> sum;        // [u] := root から u までパス上の各頂点に対して価値の総和
vector<long long> sum_adj;    // [u] := root から u までパス上の各頂点に対して、「その頂点に隣接する頂点の価値の和(自身を除く)」の 総和


void reconstruct();

EulerTour E;
SparseTable<pair<int,int>> S;

void init(){
    reconstruct();
    vector<pair<int,int>> A;
    E = EulerTour(G,root);
    for(int i = 0 ; i <int(E.tour.size()) ; i++)A.emplace_back(depth[E.tour[i]] , E.tour[i]);
    S = SparseTable(A);
}


int LCA(int u, int v) {
    int l = min(E.Left[u] , E.Left[v]);
    int r = max(E.Left[u] , E.Left[v]);
    return S.query(l,r+1).second;
}

// r を根とした場合の u と v の LCA
int LCA(int r, int u, int v) {
    return LCA(u, v) ^ LCA(u, r) ^ LCA(v, r);
}

long long query(int u, int v) {
    if (u == v) return a_adj[u] + a[u];
    long long res = 0;
    int lca = LCA(u, v);

    // パス上の各頂点の「隣接する頂点の和」の和
    res += sum_adj[u] + sum_adj[v] - 2 * sum_adj[lca] + a_adj[lca];

    // 求めたいものと比べて、重複している部分を引く
    res -= sum[u] + sum[v] - 2 * sum[lca] + a[lca];

    // 弾きすぎた部分を元に戻す
    res += a[u] + a[v];
    return res;
}

// u-v パスに頂点 x が隣接しているかどうか
bool is_adjacent(int u, int v, int x) {
    int lca = LCA(x, u, v);  // x を根とした場合の u , v の LCA
    int d = depth[x] + depth[lca] - 2 * depth[LCA(x, lca)];
    if (d <= 1)
        return true;
    else
        return false;
}



void reconstruct() {
    for (long long &x : sum) x = 0;
    for (long long &x : sum_adj) x = 0;

    stack<int> s;
    s.push(root);
    vector<bool> memo(N + 1, false);

    sum[root] = a[root];
    sum_adj[root] = a_adj[root];

    depth[root] = 0;

    // dfs で色々と計算する
    while (!s.empty()) {
        int now = s.top();
        s.pop();
        memo[now] = true;
        for (int nx : G[now]) {
            if (memo[nx]) continue;
            s.push(nx);
            sum[nx] = sum[now] + a[nx];
            sum_adj[nx] = sum_adj[now] + a_adj[nx];
            depth[nx] = depth[now] + 1;
        }
    }
}

int main() {
    int Q;
    cin >> N >> Q;

    G.resize(N + 1);
    a.resize(N + 1, 0);
    a_adj.resize(N + 1, 0);
    sum.resize(N + 1, 0);
    sum_adj.resize(N + 1, 0);
    depth.resize(N + 1, 0);

    for (int i = 0; i < N; i++) cin >> a[i + 1];

    for (int i = 0; i < N - 1; i++) {
        int u, v;
        cin >> u >> v;
        G[u].push_back(v);
        G[v].push_back(u);
    }

    for (int u = 1; u < N + 1; u++) {
        for (int nx : G[u]) a_adj[nx] += a[u];
    }

    init();


    int border = sqrt(N);
    vector<pair<int, long long> > query_stack;

    for (int c = 1; c <= Q; c++) {
        if (c % border == 0) {
            while (int(query_stack.size()) > 0) {
                int x = query_stack.back().first;
                long long v = query_stack.back().second;
                query_stack.pop_back();
                a[x] += v;
            }
            for (int i = 0; i < N + 1; i++) a_adj[i] = 0;
            for (int u = 1; u < N + 1; u++) {
                for (int nx : G[u]) a_adj[nx] += a[u];
            }
            reconstruct();
        }

        int t, u, v;
        cin >> t >> u >> v;
        if (t == 0) {
            query_stack.emplace_back(u, (long long)v);
        } else {
            long long res = query(u, v);
            for (pair<int, long long> data : query_stack) {
                if (is_adjacent(u, v, data.first)) {
                    res += data.second;
                }
            }
            cout << res << endl;
        }
    }

    return 0;
}
0