結果
問題 | No.2115 Making Forest Easy |
ユーザー | noshi91 |
提出日時 | 2022-10-28 23:10:55 |
言語 | C++17(gcc12) (gcc 12.3.0 + boost 1.87.0) |
結果 | AC |
実行時間 | 13 ms / 2,000 ms |
コード長 | 5,554 bytes |
コンパイル時間 | 2,101 ms |
コンパイル使用メモリ | 154,524 KB |
実行使用メモリ | 6,948 KB |
最終ジャッジ日時 | 2024-07-06 02:23:53 |
合計ジャッジ時間 | 3,806 ms |
ジャッジサーバーID (参考情報) | judge3 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 50 |
ソースコード
// Original author's remark (translated from Japanese): "Help me."
//
// Reads a tree (N, then A[0..N-1], then N-1 edges, 1-indexed) from stdin and
// prints one value modulo 998244353.
#include <algorithm>
#include <cassert>
#include <iostream>
#include <memory>
#include <random>
#include <stack>
#include <tuple>
#include <utility>
#include <vector>

#define rep(i, a, b) for (int i = a; i < int(b); i++)

/// Integer modulo m, stored in [0, m).
/// Division multiplies by r^(m-2) (Fermat's little theorem), so it is only
/// meaningful when m is prime and r != 0.
template <unsigned m> struct modint {
  unsigned v;
  modint(long long x = 0) : v() {
    x %= m;
    if (x < 0) x += m;  // normalise negative remainders into [0, m)
    v = static_cast<unsigned>(x);
  }
  modint operator+(modint r) const { return modint(*this) += r; }
  modint operator-(modint r) const { return modint(*this) -= r; }
  modint operator*(modint r) const { return modint(*this) *= r; }
  modint operator/(modint r) const { return modint(*this) /= r; }
  modint &operator+=(modint r) {
    v += r.v;
    if (v >= m) v -= m;
    return *this;
  }
  modint &operator-=(modint r) {
    if (v < r.v) v += m;
    v -= r.v;
    return *this;
  }
  modint &operator*=(modint r) {
    // Widen before multiplying: v * r.v can exceed 32 bits.
    v = static_cast<unsigned long long>(v) * r.v % m;
    return *this;
  }
  modint &operator/=(modint r) {
    // *this *= r^(m-2) by binary exponentiation.
    unsigned e = m - 2;
    while (e) {
      if (e & 1) *this *= r;
      r *= r;
      e >>= 1;
    }
    return *this;
  }
};
using mint = modint<998244353>;

/// Dual number c0 + c1*eps over mint, with eps^2 = 0.
/// Multiplication is commutative. A product of k factors (1 + eps) equals
/// 1 + k*eps, so c1 can be used to count such factors.
struct qmi {
  mint c0, c1;
  qmi(mint c0_, mint c1_ = 0) : c0(c0_), c1(c1_) {}
  friend qmi operator+(const qmi &x, const qmi &y) {
    return {x.c0 + y.c0, x.c1 + y.c1};
  }
  friend qmi operator*(const qmi &x, const qmi &y) {
    // (a0 + a1 eps)(b0 + b1 eps) = a0 b0 + (a1 b0 + a0 b1) eps
    return {x.c0 * y.c0, x.c1 * y.c0 + x.c0 * y.c1};
  }
  friend qmi operator*(const mint x, const qmi &y) {
    return {x * y.c0, x * y.c1};
  }
  qmi &operator*=(const qmi &y) { return *this = *this * y; }
  qmi &operator+=(const qmi &y) {
    c0 += y.c0;
    c1 += y.c1;
    return *this;
  }
};

std::mt19937_64 rng(0xE19E71EBC6129513ULL);  // fixed seed: deterministic runs
using pri_t = std::mt19937_64::result_type;

struct node_t;
using ptr_t = std::unique_ptr<node_t>;

/// Treap node: BST-ordered by key A, max-heap-ordered by pri.
/// Represents a sparse map key A -> val.
/// Invariant: the true subtree aggregates are mul * sum and mul * psum.
/// Pending multipliers of ancestors still apply on top of that.
struct node_t {
  ptr_t l, r;
  pri_t pri;
  qmi val;   // value stored at key A (before this node's pending mul)
  qmi sum;   // sum of val over the subtree (before this node's mul)
  qmi psum;  // sum of A * val over the subtree (before this node's mul)
  qmi mul;   // lazy multiplier not yet pushed into val/sum/psum/children
  int A;
  node_t(int A_, qmi val_)
      : l(), r(), pri(rng()), val(val_), sum(val_), psum(mint(A_) * val_),
        mul(1, 0), A(A_) {}
};

/// Sum of val over subtree p with p's pending multiplier applied (0 if empty).
qmi sum(const ptr_t &p) { return p ? p->mul * p->sum : qmi(0); }
/// Sum of A * val over subtree p with p's pending multiplier applied (0 if empty).
qmi psum(const ptr_t &p) { return p ? p->mul * p->psum : qmi(0); }
/// Priority of p, with 0 as the sentinel for an empty tree.
pri_t pri(const ptr_t &p) { return p ? p->pri : 0; }
/// Lazily multiplies every value in subtree p by v (no-op on an empty tree).
void mult(ptr_t &p, const qmi v) {
  if (p) p->mul *= v;
}

/// Applies n's pending multiplier to its own fields and hands it to the children.
void push(node_t &n) {
  n.val *= n.mul;
  n.sum *= n.mul;
  n.psum *= n.mul;
  mult(n.l, n.mul);
  mult(n.r, n.mul);
  n.mul = qmi(1, 0);
}

/// Recomputes n's aggregates from its children. Callers ensure n.mul is the
/// identity (n was pushed or freshly built).
void fix(node_t &n) {
  n.sum = sum(n.l) + n.val + sum(n.r);
  n.psum = psum(n.l) + mint(n.A) * n.val + psum(n.r);
}

/// Splits p into (keys < A, the node with key == A or null, keys > A).
/// The middle node is returned detached (no children) and already pushed.
std::tuple<ptr_t, ptr_t, ptr_t> split(ptr_t p, const int A) {
  if (!p) return {nullptr, nullptr, nullptr};
  push(*p);
  if (p->A < A) {
    auto [l, m, r] = split(std::move(p->r), A);
    p->r = std::move(l);
    fix(*p);
    return {std::move(p), std::move(m), std::move(r)};
  } else if (p->A > A) {
    auto [l, m, r] = split(std::move(p->l), A);
    p->l = std::move(r);
    fix(*p);
    return {std::move(l), std::move(m), std::move(p)};
  } else {
    auto l = std::move(p->l);
    auto r = std::move(p->r);
    fix(*p);
    return {std::move(l), std::move(p), std::move(r)};
  }
}

/// Merges treaps x and y into r with r[k] = x[k]*Y(<k) + y[k]*X(<=k).
/// Here X and Y are prefix sums that also include the offsets xs / ys: the
/// summed values of keys below the current subtree range.
/// With xs = ys = 0 this is the max-convolution r[k] = sum_{max(i,j)=k} x[i]*y[j].
/// Equivalently, the prefix sums of r are the products of those of x and y.
ptr_t conv_(ptr_t x, qmi xs, ptr_t y, qmi ys) {
  // Make x the root with the higher priority. An empty x always yields to y,
  // so a node whose random priority happens to be 0 can never be paired with
  // an empty x and dereferenced below.
  if (!x || pri(x) < pri(y)) {
    std::swap(x, y);
    std::swap(xs, ys);
  }
  if (!y) {
    // Every key of x is paired with the whole known prefix of y.
    mult(x, ys);
    return x;
  }
  push(*x);
  auto [l, m, r] = split(std::move(y), x->A);
  {
    // The left subtrees see the same prefixes as the whole subtree.
    // Afterwards xs / ys must cover everything strictly below x->A.
    qmi xs_ = xs, ys_ = ys;
    xs += sum(x->l);
    ys += sum(l);
    x->l = conv_(std::move(x->l), xs_, std::move(l), ys_);
  }
  xs += x->val;     // xs = X(<= x->A)
  x->val *= ys;     // x[k] * Y(< k)
  if (m) {
    x->val += m->val * xs;  // + y[k] * X(<= k)
    ys += m->val;           // ys = Y(<= x->A)
  }
  x->r = conv_(std::move(x->r), xs, std::move(r), ys);
  fix(*x);
  return x;
}

/// Max-convolution of two treaps (see conv_). Consumes both inputs.
ptr_t conv(ptr_t x, ptr_t y) {
  return conv_(std::move(x), qmi(0), std::move(y), qmi(0));
}

/// Inserts the detached node e into treap p. If key e->A already exists,
/// e->val is added to that entry instead of creating a duplicate key.
void insert_(ptr_t &p, ptr_t e) {
  // `!p` is checked explicitly so a zero priority on e cannot lead to push(*p)
  // on an empty tree.
  if (!p || pri(p) < e->pri) {
    auto [l, m, r] = split(std::move(p), e->A);
    if (m) e->val += m->val;
    e->l = std::move(l);
    e->r = std::move(r);
    fix(*e);
    p = std::move(e);
    return;
  }
  push(*p);
  if (p->A < e->A) {
    insert_(p->r, std::move(e));
  } else if (p->A > e->A) {
    insert_(p->l, std::move(e));
  } else {
    p->val += e->val;
  }
  fix(*p);
}

/// Adds val at key A in treap p (creating the key if absent).
void insert(ptr_t &p, const int A, const qmi val) {
  insert_(p, std::make_unique<node_t>(A, val));
}

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  int N;
  std::cin >> N;
  std::vector<int> A(N);
  for (auto &e : A) std::cin >> e;
  std::vector<std::vector<int>> g(N);
  rep(i, 0, N - 1) {
    int u, v;
    std::cin >> u >> v;
    u--;
    v--;
    g[u].push_back(v);
    g[v].push_back(u);
  }
  qmi ans = qmi(0);
  const mint half = mint(1) / 2;
  std::vector<int> parent(N, -1);
  std::vector<ptr_t> dp(N);

  // Iterative DFS from vertex 0 (avoids deep recursion on path-like trees).
  // Records a preorder and parent[].
  std::vector<int> dfs;
  {
    std::stack<int> st;
    st.push(0);
    while (!st.empty()) {
      const int v = st.top();
      st.pop();
      dfs.push_back(v);
      for (const int u : g[v]) {
        if (u != parent[v]) {
          st.push(u);
          parent[u] = v;
        }
      }
    }
  }
  // Reversed preorder: every vertex comes after all of its descendants.
  std::reverse(dfs.begin(), dfs.end());

  for (const int v : dfs) {
    // dp[v] starts as the single key A[v] with weight 1 + eps.
    ptr_t ret = std::make_unique<node_t>(A[v], qmi(1, 1));
    for (const int u : g[v]) {
      if (u == parent[v]) continue;
      auto t = std::move(dp[u]);
      // Presumably: with weight 1/2 the edge (v, u) is cut, and u's closed
      // component contributes its psum (sum of key * weight).
      ans += half * psum(t);
      // Otherwise, with weight 1/2, the edge is kept. The cut case is
      // represented by key 0 (neutral for max, assuming A >= 0; confirm
      // against the problem constraints).
      mult(t, qmi(half));
      insert(t, 0, qmi(half));
      ret = conv(std::move(ret), std::move(t));
    }
    dp[v] = std::move(ret);
  }
  ans += psum(dp[0]);
  // Undo the 1/2 weight per edge: multiply by 2^(N-1).
  rep(i, 0, N - 1) ans = 2 * ans;
  std::cout << ans.c1.v << "\n";
  return 0;
}