結果

問題 No.650 行列木クエリ
ユーザー Ricky_pon
提出日時 2020-09-19 11:56:06
言語 C++17 (gcc 12.3.0 + boost 1.83.0)
結果 AC
実行時間 531 ms / 2,000 ms
コード長 9,676 bytes
コンパイル時間 2,335 ms
コンパイル使用メモリ 210,028 KB
実行使用メモリ 30,112 KB
最終ジャッジ日時 2023-09-05 14:21:37
合計ジャッジ時間 5,242 ms
ジャッジサーバーID (参考情報) judge12 / judge14
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 2 ms 4,376 KB
testcase_01 AC 145 ms 7,976 KB
testcase_02 AC 531 ms 28,156 KB
testcase_03 AC 2 ms 4,376 KB
testcase_04 AC 148 ms 7,872 KB
testcase_05 AC 518 ms 28,220 KB
testcase_06 AC 2 ms 4,384 KB
testcase_07 AC 2 ms 4,380 KB
testcase_08 AC 169 ms 8,456 KB
testcase_09 AC 333 ms 30,112 KB
testcase_10 AC 1 ms 4,380 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
// Loop shorthands.
//   For(i, a, b)  : int i runs upward over [a, b).
//   rFor(i, a, b) : int i runs downward from a-1 to b inclusive.
//   rep(i, n)     : For(i, 0, n).
//   rrep(i, n)    : rFor(i, n, 0), i.e. n-1 down to 0.
#define For(i, a, b) for (int(i) = (int)(a); (i) < (int)(b); ++(i))
#define rFor(i, a, b) for (int(i) = (int)(a)-1; (i) >= (int)(b); --(i))
#define rep(i, n) For((i), 0, (n))
#define rrep(i, n) rFor((i), (n), 0)
// Short names for std::pair members.
// NOTE(review): these object-like macros rewrite every later token spelled
// `fi` or `se`, not just pair member accesses.
#define fi first
#define se second
// NOTE(review): file-scope using-directive; the rest of the file relies on
// unqualified std names (vector, pair, array, swap, ...).
using namespace std;
// Integer aliases used throughout the file.
typedef long long lint;
typedef unsigned long long ulint;
typedef pair<int, int> pii;
typedef pair<lint, lint> pll;
/// Raises `a` to `b` when `a < b`.
/// @return true iff `a` was replaced.
template <class T>
bool chmax(T &a, const T &b) {
    if (!(a < b)) return false;  // already at least b: leave untouched
    a = b;
    return true;
}
/// Lowers `a` to `b` when `a > b`.
/// @return true iff `a` was replaced.
template <class T>
bool chmin(T &a, const T &b) {
    const bool replace = a > b;
    if (replace) a = b;
    return replace;
}
/// floor(a / b) for integer types, for any signs of a and b (b != 0).
/// The built-in `/` truncates toward zero, so negative quotients need fixing up.
template <class T>
T div_floor(T a, T b) {
    if (b < 0) {
        // Normalize to a positive divisor; the quotient is unchanged.
        a *= -1;
        b *= -1;
    }
    if (a < 0) return (a + 1) / b - 1;
    return a / b;
}
/// ceil(a / b) for integer types, for any signs of a and b (b != 0).
/// The built-in `/` truncates toward zero, which already is the ceiling
/// for non-positive quotients; positive ones need fixing up.
template <class T>
T div_ceil(T a, T b) {
    if (b < 0) {
        // Normalize to a positive divisor; the quotient is unchanged.
        a *= -1;
        b *= -1;
    }
    if (a <= 0) return a / b;
    return (a - 1) / b + 1;
}

// Prime modulus 1e9+7 (the same literal that mint is instantiated with).
// NOTE(review): mod, INF and MAX are not referenced by the code shown here.
constexpr lint mod = 1'000'000'007;
// Large sentinel: mod^2 is about 1.0e18, which is still below LLONG_MAX (~9.22e18).
constexpr lint INF = mod * mod;
// Presumably a generic array-size bound — confirm before relying on it.
constexpr int MAX = 200'010;

/// Link-cut tree over a forest of `n` nodes (indices 0..n-1) that maintains
/// ordered path aggregates.
///
/// Template parameters:
///   T - node value / aggregate type.
///   E - payload type accepted by update_val.
///   F - combine, called as f(left, right) -> T. The aggregate of a path
///       v0, v1, ..., vk is f(...f(f(val[v0], val[v1]), val[v2])..., val[vk]).
///       Because splaying regroups the operands, f must be associative; it
///       need not be commutative.
///   G - point update, called as g(old_val, x) -> new value.
///
/// Every splay node caches its subtree aggregate in both directions
/// (left_sum: left to right, right_sum: right to left). A lazy reversal
/// (`rev`) therefore only has to swap the two sums.
template <typename T, typename E, typename F, typename G>
struct LinkCutTree {
    struct Node {
        // Splay parent. When is_root() holds, this is instead the path-parent
        // pointer to the node this preferred path hangs from (or nullptr).
        Node *par_ptr = nullptr;
        // Splay children: left holds nodes closer to the represented root,
        // right holds deeper ones (swapped while `rev` is pending).
        Node *left_ptr = nullptr;
        Node *right_ptr = nullptr;
        int idx = -1;      // this node's index in LinkCutTree::nodes
        bool rev = false;  // pending reversal: own children/sums not yet swapped
        T val{};           // the node's own value
        T left_sum{};      // aggregate of the splay subtree, left to right
        T right_sum{};     // aggregate of the splay subtree, right to left

        // True if this node is the root of its splay tree, i.e. par_ptr is
        // absent or only a path-parent pointer.
        bool is_root() const {
            return !par_ptr ||
                   (par_ptr->left_ptr != this && par_ptr->right_ptr != this);
        }

        // Debug dump to stdout. The %lld conversions assume T is long long;
        // do not call this for any other T.
        void print_data() const {
            int p = par_ptr ? par_ptr->idx : -1;
            int l = left_ptr ? left_ptr->idx : -1;
            int r = right_ptr ? right_ptr->idx : -1;
            printf(
                "id=%d, p=%d, l=%d, r=%d, rev=%d, val=%lld, lsum=%lld, "
                "rsum=%lld\n",
                idx, p, l, r, rev, val, left_sum, right_sum);
        }
    };

    int n;                         // number of nodes
    std::vector<Node> nodes;       // node storage; all Node* point into here
    F f;                           // combine operation
    G g;                           // point-update operation
    std::vector<Node *> push_buf;  // scratch stack reused by toggle_all

    /// Creates n isolated nodes whose value and cached sums all equal `et`.
    LinkCutTree(int n, T et, F f, G g) : n(n), f(f), g(g) {
        nodes.resize(n);
        for (int i = 0; i < n; ++i) {
            nodes[i].idx = i;
            nodes[i].val = nodes[i].left_sum = nodes[i].right_sum = et;
        }
    }

    /// Applies a pending reversal at `node`: swaps its children and its two
    /// cached sums, and forwards the flag to the children.
    void toggle(Node *node) {
        if (!node->rev) return;
        if (node->left_ptr) node->left_ptr->rev ^= true;
        if (node->right_ptr) node->right_ptr->rev ^= true;
        using std::swap;
        swap(node->left_ptr, node->right_ptr);
        swap(node->left_sum, node->right_sum);
        node->rev = false;
    }

    /// Recomputes node's cached sums from its value and children, resolving
    /// the children's pending reversals first so their sums are in order.
    void pull(Node *node) {
        node->left_sum = node->right_sum = node->val;
        if (node->left_ptr) {
            toggle(node->left_ptr);
            node->left_sum = f(node->left_ptr->left_sum, node->left_sum);
            node->right_sum = f(node->right_sum, node->left_ptr->right_sum);
        }
        if (node->right_ptr) {
            toggle(node->right_ptr);
            node->left_sum = f(node->left_sum, node->right_ptr->left_sum);
            node->right_sum = f(node->right_ptr->right_sum, node->right_sum);
        }
    }

    /// Rotates `node`, the right child of its splay parent, one level up.
    void rotl(Node *node) {
        Node *par = node->par_ptr, *grand_par = par->par_ptr;
        if ((par->right_ptr = node->left_ptr)) node->left_ptr->par_ptr = par;
        node->left_ptr = par;
        par->par_ptr = node;
        pull(par);
        pull(node);
        if ((node->par_ptr = grand_par)) {
            // grand_par may be a path-parent rather than a splay parent; then
            // neither branch matches and only the pointer is inherited.
            if (grand_par->left_ptr == par) grand_par->left_ptr = node;
            if (grand_par->right_ptr == par) grand_par->right_ptr = node;
            pull(grand_par);
        }
    }

    /// Rotates `node`, the left child of its splay parent, one level up.
    void rotr(Node *node) {
        Node *par = node->par_ptr, *grand_par = par->par_ptr;
        if ((par->left_ptr = node->right_ptr)) node->right_ptr->par_ptr = par;
        node->right_ptr = par;
        par->par_ptr = node;
        pull(par);
        pull(node);
        if ((node->par_ptr = grand_par)) {
            if (grand_par->left_ptr == par) grand_par->left_ptr = node;
            if (grand_par->right_ptr == par) grand_par->right_ptr = node;
            pull(grand_par);
        }
    }

    /// Resolves pending reversals on `node` and all its splay ancestors,
    /// top-down, so that rotations see correct child pointers. Uses an
    /// explicit stack instead of recursion so that a deep splay tree cannot
    /// overflow the call stack.
    void toggle_all(Node *node) {
        push_buf.clear();
        for (Node *cur = node;; cur = cur->par_ptr) {
            push_buf.push_back(cur);
            if (cur->is_root()) break;
        }
        for (auto it = push_buf.rbegin(); it != push_buf.rend(); ++it) {
            toggle(*it);
        }
    }

    /// Makes nodes[i] the root of its splay tree (zig / zig-zig / zig-zag).
    void splay(int i) {
        Node *node = &nodes[i];
        toggle_all(node);
        while (!node->is_root()) {
            Node *par = node->par_ptr;
            if (par->is_root()) {
                if (par->right_ptr == node)
                    rotl(node);
                else
                    rotr(node);
                return;
            }

            Node *grand_par = par->par_ptr;
            if (par->left_ptr == node) {
                if (grand_par->left_ptr == par)
                    rotr(par), rotr(node);
                else
                    rotr(node), rotl(node);
            } else {
                if (grand_par->right_ptr == par)
                    rotl(par), rotl(node);
                else
                    rotl(node), rotr(node);
            }
        }
    }

    /// Makes the path from the represented root to nodes[i] one splay tree
    /// with nodes[i] at its top and no right child. Afterwards,
    /// nodes[i].left_sum is the aggregate of that root-to-i path.
    /// Returns the last node visited while climbing the path-parent chain.
    Node *expose(int i) {
        Node *child = nullptr;
        for (Node *par = &nodes[i]; par; par = par->par_ptr) {
            splay(par->idx);
            par->right_ptr = child;  // old right part stays as a path-parent subtree
            pull(par);
            child = par;
        }
        splay(i);
        return child;
    }

    /// Attaches `child` below `par`.
    /// Precondition: `child` is the root of its represented tree, and the two
    /// nodes lie in different trees (add_edge establishes both).
    void link(int child, int par) {
        expose(child);
        expose(par);
        nodes[child].par_ptr = &nodes[par];
        nodes[par].right_ptr = &nodes[child];
        pull(&nodes[par]);
    }

    /// Removes the edge between `child` and its parent under the current root.
    /// Does nothing if `child` is the root of its represented tree.
    void cut(int child) {
        expose(child);
        Node *upper = nodes[child].left_ptr;  // splay tree of child's ancestors
        if (!upper) return;  // no ancestors: nothing to cut
        nodes[child].left_ptr = nullptr;
        upper->par_ptr = nullptr;
        pull(&nodes[child]);
    }

    /// Re-roots the represented tree containing nodes[i] at i.
    void evert(int i) {
        expose(i);
        nodes[i].rev = true;
        toggle(&nodes[i]);
    }

    /// Adds the edge (child, par); the two nodes must be in different trees.
    void add_edge(int child, int par) {
        evert(par);
        evert(child);
        link(child, par);
    }

    /// Removes the edge (child, par); the two nodes must be adjacent.
    void del_edge(int child, int par) {
        evert(par);
        cut(child);
    }

    /// Sets nodes[i].val = g(nodes[i].val, x) without changing the tree's root.
    void update_val(int i, E x) {
        // expose() puts nodes[i] at the top of its splay tree, so no other
        // node's cached sums include it. Unlike evert(), it does not re-root
        // the represented tree, so a later cut() still sees the same parents.
        expose(i);
        nodes[i].val = g(nodes[i].val, x);
        pull(&nodes[i]);
    }

    /// Aggregate of the values on the path u -> v, combined in that order.
    /// Precondition: u and v are connected. Side effect: re-roots at u.
    T get_path_sum(int u, int v) {
        evert(u);
        expose(v);
        return nodes[v].left_sum;
    }
};

/// Residue class modulo the compile-time modulus MOD, stored normalized in
/// [0, MOD). inv() and operator/ use Fermat's little theorem, so they are
/// only meaningful when MOD is prime and the divisor is non-zero.
template <int_fast64_t MOD>
struct modint {
    using i64 = int_fast64_t;
    i64 a;  // representative, always in [0, MOD)

    /// Reduces any i64 (including negative values and exact multiples of
    /// MOD) into [0, MOD). The common in-range case skips the division.
    modint(const i64 a_ = 0) : a(a_) {
        if (a >= MOD || a < 0) {
            a %= MOD;
            if (a < 0) a += MOD;
        }
    }

    /// Multiplicative inverse, computed as a^(MOD-2) by binary exponentiation.
    modint inv() const {
        i64 t = 1, x = a;
        for (i64 e = MOD - 2; e; e >>= 1) {
            if (e & 1) (t *= x) %= MOD;
            (x *= x) %= MOD;
        }
        return modint(t);
    }
    bool operator==(const modint x) const { return a == x.a; }
    bool operator!=(const modint x) const { return a != x.a; }
    modint operator+(const modint x) const { return modint(*this) += x; }
    modint operator-(const modint x) const { return modint(*this) -= x; }
    modint operator*(const modint x) const { return modint(*this) *= x; }
    modint operator/(const modint x) const { return modint(*this) /= x; }
    /// Power; see operator^=.
    modint operator^(const long long x) const { return modint(*this) ^= x; }
    modint &operator+=(const modint &x) {
        a += x.a;
        if (a >= MOD) a -= MOD;
        return *this;
    }
    modint &operator-=(const modint &x) {
        a -= x.a;
        if (a < 0) a += MOD;
        return *this;
    }
    modint &operator*=(const modint &x) {
        (a *= x.a) %= MOD;  // both < MOD, so the product fits in 64 bits
        return *this;
    }
    modint &operator/=(modint x) {
        (a *= x.inv().a) %= MOD;
        return *this;
    }
    /// Raises to the n-th power by binary exponentiation. A negative n is a
    /// power of the inverse: x^-k == (x^-1)^k.
    modint &operator^=(long long n) {
        // Work on the magnitude as unsigned so that n == LLONG_MIN is safe.
        unsigned long long e = n < 0 ? 0ULL - static_cast<unsigned long long>(n)
                                     : static_cast<unsigned long long>(n);
        if (n < 0) *this = inv();
        i64 ret = 1;
        for (; e; e >>= 1) {
            if (e & 1) (ret *= a) %= MOD;
            (a *= a) %= MOD;
        }
        a = ret;
        return *this;
    }
    modint operator-() const { return modint(0) - *this; }
    modint &operator++() { return *this += 1; }
    modint &operator--() { return *this -= 1; }
    bool operator<(const modint x) const { return a < x.a; }
};

// Arithmetic modulo the prime 1e9+7.
using mint = modint<1000000007>;

// Factorials and inverse factorials, filled in by setfact().
vector<mint> fact, revfact;

/// Fills fact[k] = k! and revfact[k] = (k!)^-1 (mod 1e9+7) for 0 <= k <= n.
/// The inverse table is built backward from a single inversion, using
/// 1/(k-1)! = k * (1/k!).
void setfact(int n) {
    fact.resize(n + 1);
    revfact.resize(n + 1);

    fact[0] = 1;
    for (int k = 1; k <= n; ++k) fact[k] = fact[k - 1] * mint(k);

    revfact[n] = fact[n].inv();
    for (int k = n; k > 0; --k) revfact[k - 1] = revfact[k] * mint(k);
}

/// Binomial coefficient C(n, r) modulo 1e9+7, read from the setfact() tables.
/// Returns 0 when r < 0 or r > n (this also covers every n < 0).
/// Precondition: setfact(m) has been called with m >= n.
mint getC(int n, int r) {
    if (r < 0 || n < r) return 0;  // r < 0 would index revfact out of bounds
    return fact[n] * revfact[r] * revfact[n - r];
}

// 2x2 matrix stored row-major: {m00, m01, m10, m11}.
using mat = array<mint, 4>;

// yukicoder No.650: a tree on vertices 0..n-1 whose edges carry 2x2 matrices.
// Edge i ("a b") becomes an extra LCT node n+i placed between a and b. Every
// node starts as the identity, so a path product only picks up edge matrices.
// A query starting with 'x' overwrites edge `id`'s matrix. Any other query
// prints the ordered product of the matrices on the path u -> v.
int main() {
    int n = 0;
    if (scanf("%d", &n) != 1 || n <= 0) return 0;  // avoids resize(-1) below

    // Matrix product a * b (the LCT's associative combine).
    auto f = [](mat a, mat b) {
        return mat{a[0] * b[0] + a[1] * b[2], a[0] * b[1] + a[1] * b[3],
                   a[2] * b[0] + a[3] * b[2], a[2] * b[1] + a[3] * b[3]};
    };
    // Point update: replace the stored matrix.
    auto g = [](mat /*old*/, mat b) { return b; };
    mat et = {1, 0, 0, 1};
    LinkCutTree<mat, mat, decltype(f), decltype(g)> lct(n * 2 - 1, et, f, g);
    rep(i, n - 1) {
        int a = 0, b = 0;
        scanf("%d%d", &a, &b);
        lct.add_edge(a, n + i);
        lct.add_edge(b, n + i);
    }
    lct.evert(0);

    int q = 0;
    scanf("%d", &q);
    rep(_, q) {
        char c = 0;
        scanf(" %c", &c);
        if (c == 'x') {
            int id = 0;
            scanf("%d", &id);
            mat a;
            rep(i, 4) {
                // Read into a real long long to match %lld (mint::a is
                // int_fast64_t, which may be `long`), and let the mint
                // constructor normalize the value into [0, MOD).
                long long x = 0;
                scanf("%lld", &x);
                a[i] = mint(x);
            }
            lct.update_val(n + id, a);
        } else {
            int u = 0, v = 0;
            scanf("%d%d", &u, &v);
            auto a = lct.get_path_sum(u, v);
            rep(i, 4) printf("%lld ", static_cast<long long>(a[i].a));
            printf("\n");
        }
    }
}
0