結果
問題 | No.2115 Making Forest Easy |
ユーザー |
![]() |
提出日時 | 2022-10-28 23:10:55 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 | AC |
実行時間 | 36 ms / 2,000 ms |
コード長 | 5,554 bytes |
コンパイル時間 | 2,732 ms |
コンパイル使用メモリ | 152,280 KB |
最終ジャッジ日時 | 2025-02-08 14:58:50 |
ジャッジサーバーID (参考情報) | judge1 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 50 |
ソースコード
// Author's original note: "help me".
//
// Yukicoder No.2115 "Making Forest Easy".
//
// Outline of what the code below computes:
//  * Each of the N-1 tree edges is independently kept or cut with weight 1/2
//    each. The last loop in main multiplies by 2^(N-1), turning these
//    weighted sums back into plain sums over all 2^(N-1) edge subsets
//    (forests).
//  * Weights live in qmi = mint[eps]/(eps^2) (dual numbers). Every vertex
//    contributes the factor (1 + eps), so the eps-coefficient (c1) of a
//    component's weight counts its vertices.
//  * dp[v] is a treap keyed by A. Entry k holds the total weight of the
//    configurations of v's subtree in which the still-open component that
//    contains v has maximum A-value k.
//  * When a component is closed (the edge to its parent is cut, or it is the
//    root's component), it contributes sum_k k * weight[k] (= psum) to ans.
//  The printed c1 is therefore sum over forests, sum over components, of
//  max(A) * |component|. NOTE(review): presumably this is the quantity the
//  problem statement asks for; confirm against it.
#include <algorithm>
#include <cassert>
#include <iostream>
#include <memory>
#include <random>
#include <stack>
#include <tuple>
#include <utility>
#include <vector>

// Integer modulo m, stored in [0, m).
// Division uses Fermat's little theorem (r^(m-2)), so m must be prime.
// NOTE(review): operator+= assumes 2*(m-1) fits in unsigned.
template <unsigned m> struct modint {
  unsigned v;
  modint(long long x = 0) : v() {
    x %= m;
    if (x < 0)
      x += m;
    v = static_cast<unsigned>(x);
  }
  modint operator+(modint r) const { return modint(*this) += r; }
  modint operator-(modint r) const { return modint(*this) -= r; }
  modint operator*(modint r) const { return modint(*this) *= r; }
  modint operator/(modint r) const { return modint(*this) /= r; }
  modint &operator+=(modint r) {
    v += r.v;
    if (v >= m)
      v -= m;
    return *this;
  }
  modint &operator-=(modint r) {
    if (v < r.v)
      v += m;
    v -= r.v;
    return *this;
  }
  modint &operator*=(modint r) {
    v = static_cast<unsigned>(static_cast<unsigned long long>(v) * r.v % m);
    return *this;
  }
  // Multiplies by r^(m-2), i.e. by r^-1 when m is prime and r != 0.
  modint &operator/=(modint r) {
    unsigned e = m - 2;
    while (e) {
      if (e & 1)
        *this *= r;
      r *= r;
      e >>= 1;
    }
    return *this;
  }
};

using mint = modint<998244353>;

// Dual number c0 + c1*eps with eps^2 = 0.
// Multiplication: (x0 + x1 eps)(y0 + y1 eps) = x0 y0 + (x1 y0 + x0 y1) eps.
struct qmi {
  mint c0, c1;
  qmi(mint c0_, mint c1_ = 0) : c0(c0_), c1(c1_) {}
  friend qmi operator+(const qmi &x, const qmi &y) {
    return {x.c0 + y.c0, x.c1 + y.c1};
  }
  friend qmi operator*(const qmi &x, const qmi &y) {
    return {x.c0 * y.c0, x.c1 * y.c0 + x.c0 * y.c1};
  }
  // Scalar multiplication.
  friend qmi operator*(const mint x, const qmi &y) {
    return {x * y.c0, x * y.c1};
  }
  qmi &operator*=(const qmi &y) { return *this = *this * y; }
  qmi &operator+=(const qmi &y) {
    c0 += y.c0;
    c1 += y.c1;
    return *this;
  }
};

// Fixed seed: priorities (and hence treap shapes) are deterministic.
std::mt19937_64 rng(0xE19E71EBC6129513ULL);
using pri_t = std::mt19937_64::result_type;

struct node_t;
using ptr_t = std::unique_ptr<node_t>;

// Treap node keyed by A (max-heap on pri).
//   val  : weight stored at key A
//   sum  : sum of val over the subtree
//   psum : sum of A * val over the subtree
//   mul  : pending multiplier for the whole subtree (lazy propagation).
//          val/sum/psum of this node do NOT yet include it.
struct node_t {
  ptr_t l, r;
  pri_t pri;
  qmi val, sum, psum, mul;
  int A;
  node_t(int A_, qmi val_)
      : l(), r(), pri(rng()), val(val_), sum(val_), psum(mint(A_) * val_),
        mul(1, 0), A(A_) {}
};

// Subtree sum of val with the pending multiplier applied; 0 for empty.
qmi sum(const ptr_t &p) { return p ? p->mul * p->sum : qmi(0); }
// Subtree sum of A * val with the pending multiplier applied; 0 for empty.
qmi psum(const ptr_t &p) { return p ? p->mul * p->psum : qmi(0); }
// Priority of p; 0 for empty.
pri_t pri(const ptr_t &p) { return p ? p->pri : 0; }

// Lazily multiplies every value in subtree p by v (no-op on empty).
void mult(ptr_t &p, const qmi v) {
  if (p)
    p->mul *= v;
}

// Applies n's pending multiplier to n itself and hands it to its children.
void push(node_t &n) {
  n.val *= n.mul;
  n.sum *= n.mul;
  n.psum *= n.mul;
  mult(n.l, n.mul);
  mult(n.r, n.mul);
  n.mul = qmi(1, 0);
}

// Recomputes n's aggregates from its children. Requires n.mul == 1
// (i.e. call after push).
void fix(node_t &n) {
  n.sum = sum(n.l) + n.val + sum(n.r);
  n.psum = psum(n.l) + mint(n.A) * n.val + psum(n.r);
}

// Splits p into (keys < A, the node with key == A or null, keys > A).
// The middle node, if any, is returned with its children detached.
std::tuple<ptr_t, ptr_t, ptr_t> split(ptr_t p, const int A) {
  if (!p)
    return {nullptr, nullptr, nullptr};
  push(*p);
  if (p->A < A) {
    auto [lo, mid, hi] = split(std::move(p->r), A);
    p->r = std::move(lo);
    fix(*p);
    return {std::move(p), std::move(mid), std::move(hi)};
  }
  if (p->A > A) {
    auto [lo, mid, hi] = split(std::move(p->l), A);
    p->l = std::move(hi);
    fix(*p);
    return {std::move(lo), std::move(mid), std::move(p)};
  }
  auto lo = std::move(p->l);
  auto hi = std::move(p->r);
  fix(*p);
  return {std::move(lo), std::move(p), std::move(hi)};
}

// Max-convolution of two treaps:
//   result[k] = x[k] * Y(<k) + y[k] * X(<=k),
// where X(<k) / Y(<k) are prefix sums over keys < k. xs and ys are the
// prefix sums of x and y contributed by keys to the left of these subtrees.
// Reuses the nodes of both inputs.
ptr_t conv_(ptr_t x, qmi xs, ptr_t y, qmi ys) {
  // Keep the higher-priority root as x. Null is tested explicitly instead of
  // relying on pri(nullptr) == 0, since rng() may return 0 for a real node.
  if (!x || (y && pri(x) < pri(y))) {
    std::swap(x, y);
    std::swap(xs, ys);
  }
  if (!y) {
    // Only x remains in this key range: every x[k] is multiplied by Y(<k),
    // which is the constant ys here.
    mult(x, ys);
    return x;
  }
  push(*x);
  auto [lo, mid, hi] = split(std::move(y), x->A);
  {
    // Left part uses the incoming prefix sums. Accumulate the left subtrees'
    // sums BEFORE they are consumed by the recursive call.
    const qmi xs_left = xs;
    const qmi ys_left = ys;
    xs += sum(x->l);
    ys += sum(lo);
    x->l = conv_(std::move(x->l), xs_left, std::move(lo), ys_left);
  }
  // Now xs = X(<A), ys = Y(<A).
  xs += x->val;    // xs = X(<=A)
  x->val *= ys;    // x[A] * Y(<A)
  if (mid) {
    x->val += mid->val * xs;  // + y[A] * X(<=A)
    ys += mid->val;           // ys = Y(<=A)
  }
  x->r = conv_(std::move(x->r), xs, std::move(hi), ys);
  fix(*x);
  return x;
}

// Max-convolution of two treaps (see conv_).
ptr_t conv(ptr_t x, ptr_t y) {
  return conv_(std::move(x), qmi(0), std::move(y), qmi(0));
}

// Inserts node e into treap p. If key e->A already exists, the values are
// added together.
void insert_(ptr_t &p, ptr_t e) {
  // Explicit null test: do not rely on pri(nullptr) == 0 < e->pri.
  if (!p || pri(p) < e->pri) {
    auto [lo, mid, hi] = split(std::move(p), e->A);
    if (mid)
      e->val += mid->val;
    e->l = std::move(lo);
    e->r = std::move(hi);
    fix(*e);
    p = std::move(e);
    return;
  }
  push(*p);
  if (p->A < e->A) {
    insert_(p->r, std::move(e));
  } else if (p->A > e->A) {
    insert_(p->l, std::move(e));
  } else {
    p->val += e->val;
  }
  fix(*p);
}

// Adds val at key A in treap p.
void insert(ptr_t &p, const int A, const qmi val) {
  insert_(p, std::make_unique<node_t>(A, val));
}

// Reads N, the values A_1..A_N and N-1 edges (1-based), then prints the
// eps-coefficient of the accumulated answer modulo 998244353.
int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  int N;
  std::cin >> N;
  std::vector<int> A(N);
  for (auto &e : A)
    std::cin >> e;
  std::vector<std::vector<int>> g(N);
  for (int i = 0; i < N - 1; i++) {
    int u, v;
    std::cin >> u >> v;
    u--;
    v--;
    g[u].push_back(v);
    g[v].push_back(u);
  }

  qmi ans = qmi(0);
  const mint half = mint(1) / 2;
  std::vector<int> parent(N, -1);
  std::vector<ptr_t> dp(N);

  // Iterative DFS from vertex 0, recording the preorder and the parents.
  std::vector<int> order;
  {
    std::stack<int> st;
    st.push(0);
    while (!st.empty()) {
      const int v = st.top();
      st.pop();
      order.push_back(v);
      for (const int u : g[v]) {
        if (u != parent[v]) {
          st.push(u);
          parent[u] = v;
        }
      }
    }
  }
  // Reversed preorder: every child is processed before its parent.
  std::reverse(order.begin(), order.end());

  for (const int v : order) {
    ptr_t ret = std::make_unique<node_t>(A[v], qmi(1, 1));
    for (const int u : g[v]) {
      if (u == parent[v])
        continue;
      auto t = std::move(dp[u]);
      // Edge (v, u) cut (weight 1/2): u's open component is closed here.
      ans += half * psum(t);
      // Edge kept (weight 1/2) keeps u's distribution. Edge cut (weight 1/2)
      // adds nothing to v's component, represented by key 0.
      // NOTE(review): key 0 is neutral for max only if all A_i >= 0.
      mult(t, qmi(half));
      insert(t, 0, qmi(half));
      ret = conv(std::move(ret), std::move(t));
    }
    dp[v] = std::move(ret);
  }
  // The root's component is always closed at the end.
  ans += psum(dp[0]);
  // Undo the 1/2 weight of each of the N-1 edges.
  for (int i = 0; i < N - 1; i++)
    ans = mint(2) * ans;
  std::cout << ans.c1.v << "\n";
  return 0;
}