結果

問題 No.399 動的な領主
ユーザー xuzijian629xuzijian629
提出日時 2019-10-28 03:44:05
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
WA  
実行時間 -
コード長 11,913 bytes
コンパイル時間 2,605 ms
コンパイル使用メモリ 218,888 KB
実行使用メモリ 11,884 KB
最終ジャッジ日時 2024-09-14 21:10:24
合計ジャッジ時間 7,648 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
5,248 KB
testcase_01 AC 2 ms
5,248 KB
testcase_02 AC 2 ms
5,248 KB
testcase_03 AC 2 ms
5,376 KB
testcase_04 AC 5 ms
5,376 KB
testcase_05 AC 32 ms
5,376 KB
testcase_06 AC 482 ms
11,880 KB
testcase_07 AC 478 ms
11,876 KB
testcase_08 AC 454 ms
11,880 KB
testcase_09 AC 448 ms
11,880 KB
testcase_10 AC 5 ms
5,376 KB
testcase_11 AC 21 ms
5,376 KB
testcase_12 AC 268 ms
11,884 KB
testcase_13 AC 252 ms
11,876 KB
testcase_14 AC 83 ms
11,880 KB
testcase_15 AC 224 ms
11,880 KB
testcase_16 WA -
testcase_17 AC 445 ms
11,832 KB
testcase_18 AC 447 ms
11,748 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;

// T0: 元の配列のモノイド
// T1: T0に対する作用素モノイド
template <class T0, class T1>
class BaseLinkCutTree {
    // T0上の演算、単位元
    virtual T0 f0(T0, T0) = 0;
    const T0 u0;
    // T1上の演算、単位元
    virtual T1 f1(T1, T1) = 0;
    const T1 u1;
    // T0に対するT1の作用
    virtual T0 g(T0, T1) = 0;
    // 多数のt1(T1)に対するf1の合成
    virtual T1 p(T1, int) = 0;

    struct Node {
        int idx;
        T0 value, acc;
        T1 lazy;
        Node *l, *r, *p;
        bool rev;
        int cnt;

        bool is_root() { return !p || (p->l != this && p->r != this); }

        Node(int idx_, T0 value_, T0 u0_, T1 u1_)
            : idx(idx_), value(value_), acc(u0_), lazy(u1_), cnt(1), l(nullptr), r(nullptr), p(nullptr), rev(false) {}
    };

    vector<Node *> nodes;

    void toggle(Node *t) {
        assert(t);
        swap(t->l, t->r);
        t->rev ^= true;
    }

    void pushdown(Node *t) {
        if (t->lazy != u1) {
            if (t->l) {
                t->l->lazy = f1(t->l->lazy, t->lazy);
                t->l->value = g(t->l->value, p(t->lazy, 1));
                t->l->acc = g(t->l->acc, p(t->lazy, t->l->cnt));
            }
            if (t->r) {
                t->r->lazy = f1(t->r->lazy, t->lazy);
                t->r->value = g(t->r->value, p(t->lazy, 1));
                t->r->acc = g(t->r->acc, p(t->lazy, t->r->cnt));
            }
            t->lazy = u1;
        }
        if (t->rev) {
            if (t->l) toggle(t->l);
            if (t->r) toggle(t->r);
            t->rev = false;
        }
    }

    void pushup(Node *t) {
        t->cnt = 1;
        t->acc = t->value;
        if (t->l) t->cnt += t->l->cnt, t->acc = f0(t->l->acc, t->acc);
        if (t->r) t->cnt += t->r->cnt, t->acc = f0(t->acc, t->r->acc);
    }

    void rotr(Node *t) {
        auto *x = t->p, *y = x->p;
        if ((x->l = t->r)) t->r->p = x;
        t->r = x, x->p = t;
        pushup(x), pushup(t);
        if ((t->p = y)) {
            if (y->l == x) y->l = t;
            if (y->r == x) y->r = t;
            pushup(y);
        }
    }

    void rotl(Node *t) {
        auto *x = t->p, *y = x->p;
        if ((x->r = t->l)) t->l->p = x;
        t->l = x, x->p = t;
        pushup(x), pushup(t);
        if ((t->p = y)) {
            if (y->l == x) y->l = t;
            if (y->r == x) y->r = t;
            pushup(y);
        }
    }

    void splay(Node *t) {
        pushdown(t);
        while (!t->is_root()) {
            auto *q = t->p;
            if (q->is_root()) {
                pushdown(q), pushdown(t);
                if (q->l == t)
                    rotr(t);
                else
                    rotl(t);
            } else {
                auto *r = q->p;
                pushdown(r), pushdown(q), pushdown(t);
                if (r->l == q) {
                    if (q->l == t)
                        rotr(q), rotr(t);
                    else
                        rotl(t), rotr(t);
                } else {
                    if (q->r == t)
                        rotl(q), rotl(t);
                    else
                        rotr(t), rotl(t);
                }
            }
        }
    }

    Node *expose(Node *t) {
        Node *rp = nullptr;
        for (Node *cur = t; cur; cur = cur->p) {
            splay(cur);
            cur->r = rp;
            pushup(cur);
            rp = cur;
        }
        splay(t);
        return rp;
    }

    void link(Node *child, Node *parent) {
        expose(child);
        expose(parent);
        child->p = parent;
        parent->r = child;
        pushup(parent);
    }

    void cut(Node *child) {
        expose(child);
        auto *parent = child->l;
        child->l = nullptr;
        parent->p = nullptr;
        pushup(child);
    }

    void evert(Node *t) {
        expose(t);
        toggle(t);
        pushdown(t);
    }

    Node *lca(Node *u, Node *v) {
        if (get_root(u) != get_root(v)) return nullptr;
        expose(u);
        return expose(v);
    }

    Node *get_kth(Node *x, int k) {
        expose(x);
        while (x) {
            pushdown(x);
            if (x->r && x->r->sz > k) {
                x = x->r;
            } else {
                if (x->r) k -= x->r->sz;
                if (k == 0) return x;
                k -= 1;
                x = x->l;
            }
        }
        return nullptr;
    }

    Node *get_root(Node *x) {
        expose(x);
        while (x->l) {
            pushdown(x);
            x = x->l;
        }
        return x;
    }

    vector<Node *> get_path(Node *x) {
        vector<Node *> vs;
        function<void(Node *)> dfs = [&](Node *cur) {
            if (!cur) return;
            pushdown(cur);
            dfs(cur->r);
            vs.push_back(cur);
            dfs(cur->l);
        };
        expose(x);
        dfs(x);
        return vs;
    }

    void update_to_root(Node *t, T1 x) {
        expose(t);
        t->lazy = f1(t->lazy, x);
        t->acc = g(t->acc, p(x, t->cnt));
        t->value = g(t->value, p(x, 1));
        pushdown(t);
    }

    // childとparentを結ぶ
    void link(int child, int parent) {
        assert(!lca(nodes[child], nodes[parent]));
        link(nodes[child], nodes[parent]);
    }

    // childをparentから切り離す
    void cut(int child) {
        assert(lca(nodes[child], nodes[child]->p));
        cut(nodes[child]);
    }

public:
    BaseLinkCutTree(T0 u0_, T1 u1_) : u0(u0_), u1(u1_) {}

    // 値vで新しい頂点を登録してそのインデックスを返す
    int make_new_node(T0 v) {
        int idx = nodes.size();
        nodes.push_back(new Node(idx, v, u0, u1));
        return idx;
    }

    // aとbを結ぶ
    void connect(int a, int b) {
        assert(lca(a, b) == -1);
        evert(a);
        link(b, a);
    }

    // aとbを切り離す
    void disconnect(int a, int b) {
        assert(lca(a, b) != -1);
        evert(a);
        cut(b);
    }

    // vを根にする
    void evert(int v) { evert(nodes[v]); }

    // lcaを返す。非連結の場合は-1
    int lca(int u, int v) {
        auto r = lca(nodes[u], nodes[v]);
        if (!r) return -1;
        return r->idx;
    }

    // vから根方向にk個たどった頂点を返す
    int get_kth(int v, int k) {
        auto r = get_kth(nodes[v], k);
        if (!r) return -1;
        return r->idx;
    }

    // vの根を返す
    int get_root(int v) { return get_root(nodes[v])->idx; }

    // vの累積を求める
    T0 query(int v) {
        expose(nodes[v]);
        return nodes[v]->value;
    }

    // vにxを作用させる
    void update(int v, T1 x) {
        int r = get_root(v);
        evert(v);
        update_to_root(nodes[v], x);
        evert(r);
    }

    // uv間のパスに現れる頂点の累積を求める
    T0 query(int u, int v) {
        int r = get_root(u);
        evert(u);
        T0 ret = nodes[v]->acc;
        evert(r);
        return ret;
    }

    // uv間のパスに現れる頂点にxを作用させる
    void update(int u, int v, T1 x) {
        int r = get_root(u);
        evert(u);
        update_to_root(nodes[v], x);
        evert(r);
    }

    // vから根までのパスに現れる頂点を順に出力
    vector<int> get_path_to_root(int v) {
        vector<int> ret;
        for (const auto &u : get_path(nodes[v])) {
            ret.push_back(u->idx);
        }
        return ret;
    }

    // uvパスに現れる頂点を順に出力
    vector<int> get_path(int u, int v) {
        int r = get_root(u);
        evert(v);
        auto ret = get_path_to_root(u);
        evert(r);
        return ret;
    }
};

template <class T0, class T1>
struct MinUpdateQuery : public BaseLinkCutTree<T0, T1> {
    using BaseLinkCutTree<T0, T1>::BaseLinkCutTree;
    MinUpdateQuery() : MinUpdateQuery(numeric_limits<T0>::max(), numeric_limits<T1>::min()) {}
    T0 f0(T0 x, T0 y) override { return min(x, y); }
    T1 f1(T1 x, T1 y) override { return y == numeric_limits<T1>::min() ? x : y; }
    T0 g(T0 x, T1 y) override { return y == numeric_limits<T1>::min() ? x : y; }
    T1 p(T1 x, int len) override { return x; }
};

template <class T0, class T1>
struct SumAddQuery : public BaseLinkCutTree<T0, T1> {
    using BaseLinkCutTree<T0, T1>::BaseLinkCutTree;
    SumAddQuery() : SumAddQuery(0, 0) {}
    T0 f0(T0 x, T0 y) override { return x + y; }
    T1 f1(T1 x, T1 y) override { return x + y; }
    T0 g(T0 x, T1 y) override { return x + y; }
    T1 p(T1 x, int len) override { return x * len; }
};

template <class T0, class T1>
struct MinAddQuery : public BaseLinkCutTree<T0, T1> {
    using BaseLinkCutTree<T0, T1>::BaseLinkCutTree;
    MinAddQuery() : MinAddQuery(numeric_limits<T0>::max(), 0) {}
    T0 f0(T0 x, T0 y) override { return min(x, y); }
    T1 f1(T1 x, T1 y) override { return x + y; }
    T0 g(T0 x, T1 y) override { return x + y; }
    T1 p(T1 x, int len) override { return x; }
};

template <class T0, class T1>
struct SumUpdateQuery : public BaseLinkCutTree<T0, T1> {
    using BaseLinkCutTree<T0, T1>::BaseLinkCutTree;
    SumUpdateQuery() : SumUpdateQuery(0, numeric_limits<T1>::min()) {}
    T0 f0(T0 x, T0 y) override { return x + y; }
    T1 f1(T1 x, T1 y) override { return y == numeric_limits<T1>::min() ? x : y; }
    T0 g(T0 x, T1 y) override { return y == numeric_limits<T1>::min() ? x : y; }
    T1 p(T1 x, int len) override { return x == numeric_limits<T1>::min() ? numeric_limits<T1>::min() : x * len; }
};

template <class T0>
struct SumAffineQuery : public BaseLinkCutTree<T0, pair<T0, T0>> {
    using T1 = pair<T0, T0>;  // first * x + second
    using BaseLinkCutTree<T0, T1>::BaseLinkCutTree;
    SumAffineQuery() : SumAffineQuery(0, {1, 0}) {}
    T0 f0(T0 x, T0 y) override { return x + y; }
    T1 f1(T1 x, T1 y) override { return {x.first * y.first, x.second * y.first + y.second}; }
    T0 g(T0 x, T1 y) override { return y.first * x + y.second; }
    T1 p(T1 x, int len) override { return {x.first, x.second * len}; }
    // update(i, j, {a, b}); // [i, j)にax + bを作用
    // update(i, j, {0, a}); // update
    // update(i, j, {1, a}); // 加算
    // update(i, j, {a, 0}); // 倍
};

template <class T>
struct MinmaxAffineQuery : public BaseLinkCutTree<pair<T, T>, pair<T, T>> {
    using T0 = pair<T, T>;  // {min, max}
    using T1 = pair<T, T>;  // first * x + second
    using BaseLinkCutTree<T0, T1>::BaseLinkCutTree;
    MinmaxAffineQuery()
        : MinmaxAffineQuery({numeric_limits<T>::max(), -numeric_limits<T>::max()}, {1, 0}) {
    }  // TODO: _u1を使うとコンパイル通らない原因不明
    T0 f0(T0 x, T0 y) override { return {min(x.first, y.first), max(x.second, y.second)}; }
    T1 f1(T1 x, T1 y) override { return {x.first * y.first, x.second * y.first + y.second}; }
    T0 g(T0 x, T1 y) override {
        T0 ret = {x.first * y.first + y.second, x.second * y.first + y.second};
        if (y.first < 0) swap(ret.first, ret.second);
        return ret;
    }
    T1 p(T1 x, int len) override { return x; }
    // update(i, j, {a, b}); // [i, j)にax + bを作用
    // update(i, j, {0, a}); // update
    // update(i, j, {1, a}); // 加算
    // update(i, j, {a, 0}); // 倍
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    SumAddQuery<long long, int> lct;
    int n;
    cin >> n;
    for (int i = 0; i < n; i++) lct.make_new_node(0);
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        a--, b--;
        lct.connect(a, b);
    }

    int q;
    cin >> q;
    while (q--) {
        int a, b;
        cin >> a >> b;
        a--, b--;
        lct.update(a, b, 1);
    }

    long long ans = 0;
    for (int i = 0; i < n; i++) {
        int k = lct.query(i);
        ans += 1LL * k * (k + 1) / 2;
    }
    cout << ans << endl;
}
0