結果

問題 No.650 行列木クエリ
ユーザー pekempeypekempey
提出日時 2018-02-10 00:19:57
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
WA  
実行時間 -
コード長 4,297 bytes
コンパイル時間 1,147 ms
コンパイル使用メモリ 86,704 KB
実行使用メモリ 32,140 KB
最終ジャッジ日時 2024-10-09 08:36:00
合計ジャッジ時間 4,224 ms
ジャッジサーバーID
(参考情報)
judge4 / judge5
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
6,820 KB
testcase_01 WA -
testcase_02 WA -
testcase_03 AC 2 ms
6,820 KB
testcase_04 WA -
testcase_05 AC 711 ms
29,960 KB
testcase_06 AC 2 ms
6,816 KB
testcase_07 AC 2 ms
6,816 KB
testcase_08 WA -
testcase_09 WA -
testcase_10 AC 2 ms
6,820 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <iostream>
#include <algorithm>
#include <vector>
#include <array>
#include <functional>

using namespace std;

const int mod = 1e9 + 7;

struct Modint {
  int n;
  Modint(int n = 0) : n(n) {}
};

Modint operator+(Modint a, Modint b) { return Modint((a.n += b.n) >= mod ? a.n - mod : a.n); }
Modint operator-(Modint a, Modint b) { return Modint((a.n -= b.n) < 0 ? a.n + mod : a.n); }
Modint operator*(Modint a, Modint b) { return Modint(1LL * a.n * b.n % mod); }
Modint &operator+=(Modint &a, Modint b) { return a = a + b; }
Modint &operator-=(Modint &a, Modint b) { return a = a - b; }
Modint &operator*=(Modint &a, Modint b) { return a = a * b; }

using V = array<Modint, 2>;
using M = array<V, 2>;

const M one = { V { 1, 0 }, V { 0, 1 } };

M operator *(M a, M b) {
  M res = {};
  for (int i = 0; i < 2; i++) {
    for (int k = 0; k < 2; k++) {
      for (int j = 0; j < 2; j++) {
        res[i][j] += a[i][k] * b[k][j];
      }
    }
  }
  return res;
}

// path product query
class PPQ {
  struct node {
    node *l = nullptr;
    node *r = nullptr;
    node *p = nullptr;
    bool rev = false;
    M prodX;
    M prodY;
    M val;
  };

  vector<node> dat;

public:
  PPQ(const vector<vector<int>> &g) {
    const int n = g.size();
    dat.resize(n);
    for (int i = 0; i < n; i++) {
      dat[i].prodX = one;
      dat[i].prodY = one;
      dat[i].val = one;
    }
    for (int i = 0; i < n; i++) {
      for (int j : g[i]) {
        if (i < j) {
          link(&dat[i], &dat[j]);
        }
      }
    }
  }

  void update(int u, M m) {
    evert(&dat[u]);
    expose(&dat[u]);
    dat[u].prodX = m;
    dat[u].prodY = m;
    dat[u].val = m;
  }

  M query(int u, int v) {
    evert(&dat[u]);
    expose(&dat[v]);
    return dat[v].prodX;
  }

private:
  bool is_root(node *x) {
    return !x->p || (x != x->p->l && x != x->p->r);
  }

  M prodX(node *x) {
    return x ? x->prodX : one;
  }

  M prodY(node *x) {
    return x ? x->prodY : one;
  }

  void pull(node *x) {
    if (!x->rev) {
      x->prodX = prodX(x->l) * x->val * prodX(x->r);
      x->prodY = prodY(x->r) * x->val * prodY(x->l);
    } else {
      x->prodX = prodY(x->r) * x->val * prodY(x->l);
      x->prodY = prodX(x->l) * x->val * prodX(x->r);
    }
  }

  void rot(node *x) {
    node *y = x->p, *z = y->p;
    if (z) {
      if (y == z->l) z->l = x;
      if (y == z->r) z->r = x;
    }
    x->p = z; y->p = x;
    if (x == y->l) {
      y->l = x->r; x->r = y;
      if (y->l) y->l->p = y;
    } else {
      y->r = x->l;
      x->l = y;
      if (y->r) y->r->p = y;
    }
    pull(y);
  }

  void reverse(node *x) {
    swap(x->l, x->r);
    x->rev ^= true;
  }

  void push(node *x) {
    if (x->rev) {
      if (x->l) reverse(x->l);
      if (x->r) reverse(x->r);
    }
    x->rev = false;
  }

  void flush(node *x) {
    if (x->p) flush(x->p);
    push(x);
  }

  void splay(node *x) {
    while (!is_root(x)) {
      node *y = x->p;
      if (!is_root(y)) {
        node *z = y->p;
        rot((x == y->l) == (y == z->l) ? y : x);
      }
      rot(x);
    }
    pull(x);
  }

  void expose(node *x) {
    node *tmp = x;
    flush(x);
    for (node *r = nullptr; x; x = x->p) {
      splay(x);
      x->r = r;
      pull(x);
      r = x;
    }
    splay(tmp);
  }

  void evert(node *x) {
    expose(x);
    reverse(x);
  }

  void link(node *c, node *p) {
    evert(c);
    expose(p);
    p->r = c;
    c->p = p;
    pull(p);
  }
};


int input() {
  int n;
  scanf("%d", &n);
  return n;
}

int main() {
  int n = input();
  
  vector<vector<int>> g(n * 2 - 1);
  for (int i = 0; i < n - 1; i++) {
    int u = input();
    int v = input();
    g[u].push_back(i + n);
    g[i + n].push_back(u);
    g[v].push_back(i + n);
    g[i + n].push_back(v);
  }

  PPQ ppq(g);

  int q = input();

  while (q--) {
    char c;
    scanf(" %c", &c);

    if (c == 'g') {
      int u = input();
      int v = input();
      M ans = ppq.query(u, v);
      for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
          printf("%d ", ans[i][j].n);
        }
      }
      printf("\n");
    } else {
      int u = input();
      M m;
      for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
          m[i][j] = input();
        }
      }
      ppq.update(u + n, m);
    }
  }
}
0