結果

問題 No.2341 Triple Tree Query (Medium)
ユーザー 👑 hos.lyric
提出日時 2023-06-02 22:51:44
言語 C++14 (gcc 12.3.0 + boost 1.83.0)
結果 AC
実行時間 3,020 ms / 5,000 ms
コード長 14,690 bytes
コンパイル時間 3,296 ms
コンパイル使用メモリ 158,572 KB
実行使用メモリ 37,112 KB
最終ジャッジ日時 2024-06-09 00:38:03
合計ジャッジ時間 35,646 ms
ジャッジサーバーID (参考情報) judge2 / judge5
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 2 ms 5,248 KB
testcase_01 AC 2 ms 5,376 KB
testcase_02 AC 1,379 ms 25,844 KB
testcase_03 AC 1,366 ms 25,728 KB
testcase_04 AC 1,437 ms 25,728 KB
testcase_05 AC 1,382 ms 25,856 KB
testcase_06 AC 1,378 ms 25,752 KB
testcase_07 AC 318 ms 34,412 KB
testcase_08 AC 323 ms 33,012 KB
testcase_09 AC 306 ms 37,112 KB
testcase_10 AC 310 ms 34,688 KB
testcase_11 AC 308 ms 36,216 KB
testcase_12 AC 322 ms 32,640 KB
testcase_13 AC 320 ms 34,048 KB
testcase_14 AC 304 ms 33,508 KB
testcase_15 AC 301 ms 36,224 KB
testcase_16 AC 324 ms 31,852 KB
testcase_17 AC 202 ms 26,484 KB
testcase_18 AC 193 ms 26,484 KB
testcase_19 AC 200 ms 26,448 KB
testcase_20 AC 205 ms 26,480 KB
testcase_21 AC 198 ms 26,488 KB
testcase_22 AC 187 ms 26,488 KB
testcase_23 AC 199 ms 26,488 KB
testcase_24 AC 197 ms 26,488 KB
testcase_25 AC 2,727 ms 26,216 KB
testcase_26 AC 2,759 ms 26,004 KB
testcase_27 AC 3,020 ms 26,212 KB
testcase_28 AC 2,755 ms 26,108 KB
testcase_29 AC 2,672 ms 25,984 KB
testcase_30 AC 267 ms 26,468 KB
testcase_31 AC 259 ms 26,464 KB
testcase_32 AC 261 ms 26,476 KB
testcase_33 AC 500 ms 26,080 KB
testcase_34 AC 486 ms 26,112 KB
testcase_35 AC 511 ms 25,960 KB
testcase_36 AC 537 ms 25,980 KB
testcase_37 AC 504 ms 25,980 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <complex>
#include <deque>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <numeric>
#include <queue>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

using namespace std;

using Int = long long;

// Debug printing of a pair as "(first, second)".
template <class T1, class T2> ostream &operator<<(ostream &os, const pair<T1, T2> &a) {
  os << "(" << a.first << ", " << a.second << ")";
  return os;
}
// Debug printing of a vector as "[a, b, c]"; only the first 256 elements are
// shown, followed by ", ..." when the vector is longer.
template <class T> ostream &operator<<(ostream &os, const vector<T> &as) {
  const int sz = as.size();
  const int shown = min(sz, 256);
  os << "[";
  for (int i = 0; i < shown; ++i) {
    if (i) os << ", ";
    os << as[i];
  }
  if (sz > shown) os << ", ...";
  return os << "]";
}
// Debug helper: prints *it followed by a space for every iterator in [a, b)
// to stderr, then ends the line.
template <class T> void pv(T a, T b) {
  while (a != b) {
    cerr << *a << " ";
    ++a;
  }
  cerr << endl;
}
// Lowers t to f if f is smaller; returns whether t changed.
template <class T> bool chmin(T &t, const T &f) {
  if (!(t > f)) return false;
  t = f;
  return true;
}
// Raises t to f if f is larger; returns whether t changed.
template <class T> bool chmax(T &t, const T &f) {
  if (!(t < f)) return false;
  t = f;
  return true;
}


// fast IO by yosupo
// Buffered reader of whitespace-separated tokens from a FILE*.
// sc.read(string &) appends the input (the string is not cleared first).
struct Scanner {
    FILE* fp = nullptr;
    char line[(1 << 15) + 1];  // read window; line[ed] is set to '\0' on every refill
    size_t st = 0, ed = 0;     // unread bytes are line[st, ed)

    // The <cctype> classifiers require a value representable as unsigned char
    // (or EOF). Plain char may be signed, so passing a byte >= 0x80 (e.g. part
    // of a UTF-8 sequence) straight through is undefined behavior.
    static bool is_space(char c) { return isspace(static_cast<unsigned char>(c)) != 0; }
    static bool is_digit(char c) { return isdigit(static_cast<unsigned char>(c)) != 0; }

    // Shifts the unread bytes to the front of the buffer and refills the rest.
    void reread() {
        memmove(line, line + st, ed - st);
        ed -= st;
        st = 0;
        ed += fread(line + ed, 1, (1 << 15) - ed, fp);
        line[ed] = '\0';
    }
    // Skips whitespace, refilling as needed. Returns false when the input is
    // exhausted; otherwise st points at the next token. If 50 or fewer bytes
    // remain buffered, the buffer is refilled first so a short token (such as
    // a number) is not split across a refill.
    bool succ() {
        while (true) {
            if (st == ed) {
                reread();
                if (st == ed) return false;
            }
            while (st != ed && is_space(line[st])) st++;
            if (st != ed) break;
        }
        if (ed - st <= 50) reread();
        return true;
    }
    // Reads one token and appends it to ref; long tokens are read across refills.
    template <class T, enable_if_t<is_same<T, string>::value, int> = 0>
    bool read_single(T& ref) {
        if (!succ()) return false;
        while (true) {
            size_t sz = 0;
            while (st + sz < ed && !is_space(line[st + sz])) sz++;
            ref.append(line + st, sz);
            st += sz;
            if (!sz || st != ed) break;
            reread();
        }
        return true;
    }
    // Reads an optionally '-'-prefixed decimal integer (no overflow check).
    template <class T, enable_if_t<is_integral<T>::value, int> = 0>
    bool read_single(T& ref) {
        if (!succ()) return false;
        bool neg = false;
        if (line[st] == '-') {
            neg = true;
            st++;
        }
        ref = T(0);
        // The '\0' written after the buffered data stops this loop at ed.
        while (is_digit(line[st])) {
            ref = 10 * ref + (line[st++] - '0');
        }
        if (neg) ref = -ref;
        return true;
    }
    // Reads every element of ref; the vector must already have its final size.
    template <class T> bool read_single(vector<T>& ref) {
        for (auto& d : ref) {
            if (!read_single(d)) return false;
        }
        return true;
    }
    // Reads each argument in order, asserting that every read succeeds.
    void read() {}
    template <class H, class... T> void read(H& h, T&... t) {
        bool f = read_single(h);
        assert(f);
        read(t...);
    }
    Scanner(FILE* _fp) : fp(_fp) {}
};

// Buffered writer to a FILE*. write(a, b, ...) separates values with single
// spaces, writeln(...) additionally ends the line. The buffer is flushed when
// full and in the destructor.
struct Printer {
  public:
    template <bool F = false> void write() {}
    template <bool F = false, class H, class... T>
    void write(const H& h, const T&... t) {
        if (F) write_single(' ');
        write_single(h);
        write<true>(t...);
    }
    template <class... T> void writeln(const T&... t) {
        write(t...);
        write_single('\n');
    }

    Printer(FILE* _fp) : fp(_fp) {}
    ~Printer() { flush(); }

  private:
    static constexpr size_t SIZE = 1 << 15;
    FILE* fp;
    char line[SIZE], small[50];
    size_t pos = 0;
    void flush() {
        fwrite(line, 1, pos, fp);
        pos = 0;
    }
    void write_single(const char& val) {
        if (pos == SIZE) flush();
        line[pos++] = val;
    }
    template <class T, enable_if_t<is_integral<T>::value, int> = 0>
    void write_single(T val) {
        // Reserve room for a sign and up to 20 digits, written directly below.
        if (pos > SIZE - 50) flush();
        if (val == 0) {
            write_single('0');
            return;
        }
        size_t len = 0;
        if (val < 0) {
            write_single('-');
            // Peel digits while val is still negative: negating the minimum
            // value (e.g. LLONG_MIN) would overflow. Since C++11, % truncates
            // toward zero, so val % 10 lies in [-9, 0] here.
            while (val) {
                small[len++] = char('0' - (val % 10));
                val /= 10;
            }
        } else {
            while (val) {
                small[len++] = char('0' + (val % 10));
                val /= 10;
            }
        }
        for (size_t i = 0; i < len; i++) {
            line[pos + i] = small[len - 1 - i];
        }
        pos += len;
    }
    void write_single(const string& s) {
        for (char c : s) write_single(c);
    }
    void write_single(const char* s) {
        size_t len = strlen(s);
        for (size_t i = 0; i < len; i++) write_single(s[i]);
    }
    template <class T> void write_single(const vector<T>& val) {
        auto n = val.size();
        for (size_t i = 0; i < n; i++) {
            if (i) write_single(' ');
            write_single(val[i]);
        }
    }
    // Fixed-point output with 8 fractional digits (truncated, not rounded).
    void write_single(long double d) {
        // Handle the sign once; otherwise the integer part and every fraction
        // digit would each be printed with their own '-'.
        if (d < 0) {
            write_single('-');
            d = -d;
        }
        {
            long long v = d;
            write_single(v);
            d -= v;
        }
        write_single('.');
        for (int _ = 0; _ < 8; _++) {
            d *= 10;
            long long v = d;
            write_single(v);
            d -= v;
        }
    }
    // Without this overload a double argument is ambiguous between the
    // char and long double overloads (both are conversions).
    void write_single(double d) { write_single(static_cast<long double>(d)); }
};

// Process-wide fast I/O: tokens come from stdin, output goes to stdout
// (pr flushes its buffer in its destructor at program exit).
Scanner sc{stdin};
Printer pr{stdout};


////////////////////////////////////////////////////////////////////////////////
// Arithmetic modulo the compile-time constant M_.
// Assumes M_ <= INT_MAX (the signed constructors cast M to int), which also
// makes x + a.x in operator+= fit in unsigned.
template <unsigned M_> struct ModInt {
  static constexpr unsigned M = M_;
  unsigned x;  // value in [0, M); the operators below rely on this invariant
  constexpr ModInt() : x(0U) {}
  constexpr ModInt(unsigned x_) : x(x_ % M) {}
  constexpr ModInt(unsigned long long x_) : x(x_ % M) {}
  // Signed inputs are reduced into [0, M), so negative values wrap correctly.
  constexpr ModInt(int x_) : x(((x_ %= static_cast<int>(M)) < 0) ? (x_ + static_cast<int>(M)) : x_) {}
  constexpr ModInt(long long x_) : x(((x_ %= static_cast<long long>(M)) < 0) ? (x_ + static_cast<long long>(M)) : x_) {}
  ModInt &operator+=(const ModInt &a) { x = ((x += a.x) >= M) ? (x - M) : x; return *this; }
  // Unsigned wrap-around: x - a.x "underflows" to a value >= M exactly when x < a.x.
  ModInt &operator-=(const ModInt &a) { x = ((x -= a.x) >= M) ? (x + M) : x; return *this; }
  ModInt &operator*=(const ModInt &a) { x = (static_cast<unsigned long long>(x) * a.x) % M; return *this; }
  ModInt &operator/=(const ModInt &a) { return (*this *= a.inv()); }
  // Returns (*this)^e by binary exponentiation. A negative e means the inverse
  // raised to |e| (then x must be invertible). |e| is computed in unsigned
  // arithmetic so e == LLONG_MIN is well-defined; negating it as a signed
  // long long would overflow.
  ModInt pow(long long e) const {
    ModInt a = (e < 0) ? inv() : *this, b = 1U;
    unsigned long long k = (e < 0) ? (0ULL - static_cast<unsigned long long>(e))
                                   : static_cast<unsigned long long>(e);
    for (; k; k >>= 1) { if (k & 1) b *= a; a *= a; }
    return b;
  }
  // Multiplicative inverse by the extended Euclidean algorithm on (M, x);
  // asserts gcd(x, M) == 1.
  ModInt inv() const {
    unsigned a = M, b = x; int y = 0, z = 1;
    for (; b; ) { const unsigned q = a / b; const unsigned c = a - q * b; a = b; b = c; const int w = y - static_cast<int>(q) * z; y = z; z = w; }
    assert(a == 1U); return ModInt(y);
  }
  ModInt operator+() const { return *this; }
  ModInt operator-() const { ModInt a; a.x = x ? (M - x) : 0U; return a; }
  ModInt operator+(const ModInt &a) const { return (ModInt(*this) += a); }
  ModInt operator-(const ModInt &a) const { return (ModInt(*this) -= a); }
  ModInt operator*(const ModInt &a) const { return (ModInt(*this) *= a); }
  ModInt operator/(const ModInt &a) const { return (ModInt(*this) /= a); }
  // Mixed-type forms, e.g. 3 * m, with the plain value on the left.
  template <class T> friend ModInt operator+(T a, const ModInt &b) { return (ModInt(a) += b); }
  template <class T> friend ModInt operator-(T a, const ModInt &b) { return (ModInt(a) -= b); }
  template <class T> friend ModInt operator*(T a, const ModInt &b) { return (ModInt(a) *= b); }
  template <class T> friend ModInt operator/(T a, const ModInt &b) { return (ModInt(a) /= b); }
  explicit operator bool() const { return x; }
  bool operator==(const ModInt &a) const { return (x == a.x); }
  bool operator!=(const ModInt &a) const { return (x != a.x); }
  friend std::ostream &operator<<(std::ostream &os, const ModInt &a) { return os << a.x; }
};
////////////////////////////////////////////////////////////////////////////////

constexpr unsigned MO = 998244353;
using Mint = ModInt<MO>;


// graph: tree, vertex lists, modified (parent removed, heavy child first)
// Heavy-light decomposition. graph holds undirected adjacency lists on entry;
// runHLD(rt) rewrites it in place so that graph[u] lists only u's children,
// with the heavy child (largest subtree) first.
int N;
vector<vector<int>> graph;

vector<int> par, sz;   // parent (-1 for the root) and subtree size
int zeit;              // next preorder index to hand out
vector<int> dis, fin;  // the subtree of u occupies preorder indices [dis[u], fin[u])
vector<int> head;      // topmost vertex of the heavy path containing u

// Removes the parent entry from each child's list and fills par and sz for
// the subtree of u. Iterative (BFS order, then sizes accumulated in reverse
// BFS order) so that a path-shaped tree of depth ~N cannot overflow the call
// stack.
void dfsSz(int u) {
  vector<int> order(1, u);
  for (size_t i = 0; i < order.size(); ++i) {
    const int w = order[i];
    for (const int v : graph[w]) {
      graph[v].erase(find(graph[v].begin(), graph[v].end(), w));
      par[v] = w;
      order.push_back(v);
    }
  }
  // Every vertex appears after its parent in order, so a reverse sweep sees
  // the final sizes of all children.
  for (size_t i = order.size(); i-- > 0; ) {
    const int w = order[i];
    sz[w] = 1;
    for (const int v : graph[w]) sz[w] += sz[v];
  }
}
// Assigns preorder indices (heavy child first) and heavy-path heads in the
// subtree of u; head[u] must already be set. Uses an explicit stack instead of
// recursion; an entry ~w (negative) marks that w's whole subtree is numbered.
void dfsHLD(int u) {
  vector<int> stk(1, u);
  while (!stk.empty()) {
    const int w = stk.back();
    stk.pop_back();
    if (w < 0) {
      fin[~w] = zeit;
      continue;
    }
    dis[w] = zeit++;
    stk.push_back(~w);
    vector<int> &ch = graph[w];
    const int deg = ch.size();
    if (deg > 0) {
      // Heavy child = first child with maximum subtree size.
      int jm = 0;
      for (int j = 1; j < deg; ++j) {
        if (sz[ch[jm]] < sz[ch[j]]) jm = j;
      }
      swap(ch[0], ch[jm]);
      head[ch[0]] = head[w];
      for (int j = 1; j < deg; ++j) head[ch[j]] = ch[j];
      // Push in reverse so ch[0] (the heavy child) is numbered right after w,
      // which keeps every heavy path contiguous in preorder.
      for (int j = deg; j-- > 0; ) stk.push_back(ch[j]);
    }
  }
}
// Runs the decomposition from root rt on the first N vertices of graph.
void runHLD(int rt) {
  par.assign(N, -1);
  sz.resize(N);
  zeit = 0;
  dis.resize(N);
  fin.resize(N);
  head.resize(N);
  dfsSz(rt);
  head[rt] = rt;
  dfsHLD(rt);
}

////////////////////////////////////////////////////////////////////////////////

// Problem input (the vertex count N is declared with the HLD globals).
int Q;               // number of queries
vector<int> A;       // edge endpoints, converted to 0-indexed in main
vector<int> B;
vector<Mint> Z;      // initial vertex values

// 2-D coordinates of every vertex: X = HLD preorder (dis), Y = BFS order.
vector<int> X;
vector<int> Y;

// Points (X[u], Y[u]) the kd-tree is built over; Build reorders them in place.
vector<pair<int, int>> ps;
// kd-tree node over the points ps[pos0, pos1). The bounding box of those
// points is the half-open rectangle [x0, x1) x [y0, y1). Each node carries a
// pending affine map z -> c z + d that still applies to every point below it.
struct Node {
  int l, r;            // child indices into nodes; -1 for a leaf
  int pos0, pos1;
  bool isX;            // split axis: true = by X, false = by Y
  int x0, x1, y0, y1;
  Mint c, d;
  // Hands the pending map to both children and resets it to the identity.
  void push(Node &ll, Node &rr) {
    const bool identity = (c == 1) && (d == 0);
    if (identity) return;
    ll.ch(c, d);
    rr.ch(c, d);
    c = 1;
    d = 0;
  }
  // Composes z -> cc z + dd on top of the current map:
  // cc (c z + d) + dd = (cc c) z + (cc d + dd).
  void ch(Mint cc, Mint dd) {
    d = cc * d + dd;
    c = cc * c;
  }
};
int nodesLen;          // number of nodes handed out by Build so far
vector<Node> nodes;    // node pool; main sizes it to 2N - 1 before building
int Build(int pos0, int pos1, bool isX) {
  const int u = nodesLen++;
  Node &f = nodes[u];
  f.l = f.r = -1;
  f.pos0 = pos0;
  f.pos1 = pos1;
  f.isX = isX;
  f.c = 1;
  f.d = 0;
  f.x0 = N; f.x1 = -1;
  f.y0 = N; f.y1 = -1;
  for (int j = pos0; j < pos1; ++j) {
    chmin(f.x0, ps[j].first ); chmax(f.x1, ps[j].first );
    chmin(f.y0, ps[j].second); chmax(f.y1, ps[j].second);
  }
  ++f.x1;
  ++f.y1;
  if (pos0 + 1 < pos1) {
    const int mid = (pos0 + pos1) / 2;
    if (isX) {
      nth_element(ps.begin() + pos0, ps.begin() + mid, ps.begin() + pos1,
          [&](const pair<int, int> &p0, const pair<int, int> &p1) -> bool {
            return (p0.first  < p1.first ); 
          });
    } else {
      nth_element(ps.begin() + pos0, ps.begin() + mid, ps.begin() + pos1,
          [&](const pair<int, int> &p0, const pair<int, int> &p1) -> bool {
            return (p0.second < p1.second); 
          });
    }
    f.l = Build(pos0, mid, !isX);
    f.r = Build(mid, pos1, !isX);
  }
// cerr<<u<<": "<<f.l<<" "<<f.r<<" "<<f.isX<<"; ";pv(ps.begin()+pos0,ps.begin()+pos1);
  return u;
}
void Ch(int u, int a0, int a1, int b0, int b1, Mint c, Mint d) {
  Node &f = nodes[u];
  if (f.x1 <= a0 || a1 <= f.x0 || f.y1 <= b0 || b1 <= f.y0) {
    return;
  }
  if (a0 <= f.x0 && f.x1 <= a1 && b0 <= f.y0 && f.y1 <= b1) {
    f.ch(c, d);
    return;
  }
  f.push(nodes[f.l], nodes[f.r]);
  Ch(f.l, a0, a1, b0, b1, c, d);
  Ch(f.r, a0, a1, b0, b1, c, d);
}
pair<Mint, Mint> Get(int u, int x, int y) {
  Node &f = nodes[u];
  if (!~f.l) {
    return make_pair(f.c, f.d);
  } else {
    f.push(nodes[f.l], nodes[f.r]);
    const Node &fl = nodes[f.l];
    return Get((f.isX ? (x < fl.x1) : (y < fl.y1)) ? f.l : f.r, x, y);
  }
}

// Reads the tree, initial values and queries. Each vertex u becomes the point
// (X[u], Y[u]) = (HLD preorder, BFS order). Subtrees and heavy-path segments
// are X-intervals, and "same depth, near V" sets are Y-intervals, so every
// update is a rectangle update on a kd-tree of lazy affine maps.
int main() {
  {
    // Reads one integer and reduces it modulo MO through Mint's long long
    // constructor. Writing into .x directly would skip the reduction.
    const auto readMint = []() -> Mint {
      long long v;
      sc.read(v);
      return Mint(v);
    };

    sc.read(N, Q);
    A.resize(N - 1);
    B.resize(N - 1);
    for (int i = 0; i < N - 1; ++i) {
      sc.read(A[i], B[i]);
      --A[i];
      --B[i];
    }
    Z.resize(N);
    for (int u = 0; u < N; ++u) {
      Z[u] = readMint();
    }

    graph.assign(N, {});
    for (int i = 0; i < N - 1; ++i) {
      graph[A[i]].push_back(B[i]);
      graph[B[i]].push_back(A[i]);
    }
    runHLD(0);

    X = dis;
    Y.assign(N, 0);
    vector<int> dep(N, 0);
    // layers[d] = (X[u], u) for the vertices at depth d, in BFS order. BFS and
    // the HLD preorder visit children in the same (reordered) graph order, so
    // each layer is also sorted by X; the lower_bound calls below rely on this.
    vector<vector<pair<int, int>>> layers(N);
    {
      int y = 0;
      queue<int> que;
      que.push(0);
      for (; !que.empty(); ) {
        const int u = que.front();
        que.pop();
        Y[u] = y++;
        layers[dep[u]].emplace_back(X[u], u);
        for (const int v : graph[u]) {
          dep[v] = dep[u] + 1;
          que.push(v);
        }
      }
    }

    ps.resize(N);
    for (int u = 0; u < N; ++u) {
      ps[u] = make_pair(X[u], Y[u]);
    }
    nodesLen = 0;
    nodes.resize(2 * N - 1);
    Build(0, N, true);

    // Scratch ranges for query type 2, sized per query as 2K+1. They are
    // reused across queries; a fixed-size array would overflow for large K.
    vector<int> ls, rs;
    for (int q = 0; q < Q; ++q) {
      int O;
      sc.read(O);
      switch (O) {
        case 1: {
          // Point query: accumulated map applied to the initial value Z[V].
          int V;
          sc.read(V);
          --V;
          const auto res = Get(0, X[V], Y[V]);
          const Mint ans = res.first * Z[V] + res.second;
          pr.writeln(ans.x);
        } break;
        case 2: {
          // Apply z -> C z + D to every vertex within distance K of V.
          int V, K;
          sc.read(V, K);
          const Mint C = readMint();
          const Mint D = readMint();
          --V;
          if (K < 0) break;  // no vertex is at a negative distance
          // For t in [-K, K], [ls[K + t], rs[K + t]) is the preorder range whose
          // vertices at depth dep[V] + t are exactly those within distance K of V.
          ls.assign(2 * K + 1, N);
          rs.assign(2 * K + 1, 0);
          {
            int u = V;
            for (int k = 0; k <= K; ++k) {
              // u is the k-th ancestor of V. Its descendants kk levels below it
              // are within distance k + kk of V.
              for (int kk = 0; k + kk <= K; ++kk) {
                chmin(ls[K - k + kk], dis[u]);
                chmax(rs[K - k + kk], fin[u]);
              }
              u = par[u];
              if (!~u) break;
            }
          }
          for (int k = -K; k <= K; ++k) {
            const int d = dep[V] - k;
            if (0 <= d && d < N) {
              const auto &layer = layers[d];
              const int lb = lower_bound(layer.begin(), layer.end(), make_pair(ls[K - k], -1)) - layer.begin();
              const int ub = lower_bound(layer.begin(), layer.end(), make_pair(rs[K - k], -1)) - layer.begin();
              if (lb < ub) {
                // Consecutive within the layer, hence a contiguous Y-interval.
                Ch(0, 0, N, Y[layer[lb].second], Y[layer[ub - 1].second] + 1, C, D);
              }
            }
          }
        } break;
        case 3: {
          // Apply z -> C z + D to the subtree of V: the X-range [dis[V], fin[V]).
          int V;
          sc.read(V);
          const Mint C = readMint();
          const Mint D = readMint();
          --V;
          Ch(0, dis[V], fin[V], 0, N, C, D);
        } break;
        case 4: {
          // Apply z -> C z + D to the U-V path, one heavy-path segment at a time.
          int U, V;
          sc.read(U, V);
          const Mint C = readMint();
          const Mint D = readMint();
          --U;
          --V;
          // Closed preorder interval [x0, x1] -> half-open X-range [x0, x1 + 1).
          auto oper = [&](int x0, int x1) -> void {
            Ch(0, x0, x1 + 1, 0, N, C, D);
          };
          int u = U, v = V;
          for (; ; ) {
            if (head[u] == head[v]) {
              oper(min(dis[u], dis[v]), max(dis[u], dis[v]));
              break;
            }
            if (dep[head[u]] > dep[head[v]]) {
              oper(dis[head[u]], dis[u]);
              u = par[head[u]];
            } else {
              oper(dis[head[v]], dis[v]);
              v = par[head[v]];
            }
          }
        } break;
        default: assert(false);
      }
    }
  }
  return 0;
}
0