#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <limits>
#include <utility>
#include <vector>

// Reads one decimal int from stdin (returns 0 if nothing could be read).
int ri() {
	int n = 0;
	if (scanf("%d", &n) != 1) n = 0;
	return n;
}

// Integer modulo `mod`. `x` is always kept in [0, mod).
// inverse() uses the extended Euclidean algorithm, so it is only a true inverse
// when gcd(x, mod) == 1 (always the case for nonzero x with a prime mod).
template<int mod>
struct ModInt {
	int x;
	ModInt () : x(0) {}
	// Reduces any int64_t (including negatives) into [0, mod).
	ModInt (int64_t x) : x(static_cast<int>(x >= 0 ? x % mod : (mod - -x % mod) % mod)) {}
	ModInt &operator += (const ModInt &p) {
		if ((x += p.x) >= mod) x -= mod;
		return *this;
	}
	ModInt &operator -= (const ModInt &p) {
		if ((x += mod - p.x) >= mod) x -= mod;
		return *this;
	}
	ModInt &operator *= (const ModInt &p) {
		x = static_cast<int>(static_cast<int64_t>(x) * p.x % mod);
		return *this;
	}
	ModInt &operator /= (const ModInt &p) {
		*this *= p.inverse();
		return *this;
	}
	// Binary exponentiation; expects p >= 0 (a negative p never reaches 0 under >>=).
	ModInt &operator ^= (int64_t p) {
		ModInt res = 1;
		for (; p; p >>= 1) {
			if (p & 1) res *= *this;
			*this *= *this;
		}
		return *this = res;
	}
	ModInt operator - () const { return ModInt(-x); }
	ModInt operator + (const ModInt &p) const { return ModInt(*this) += p; }
	ModInt operator - (const ModInt &p) const { return ModInt(*this) -= p; }
	ModInt operator * (const ModInt &p) const { return ModInt(*this) *= p; }
	ModInt operator / (const ModInt &p) const { return ModInt(*this) /= p; }
	ModInt operator ^ (int64_t p) const { return ModInt(*this) ^= p; }
	bool operator == (const ModInt &p) const { return x == p.x; }
	bool operator != (const ModInt &p) const { return x != p.x; }
	explicit operator int() const { return x; }
	// Normalise through the converting constructor so that negative or
	// out-of-range ints still yield a value in [0, mod).
	ModInt &operator = (const int p) {
		x = ModInt(p).x;
		return *this;
	}
	ModInt inverse() const {
		int a = x, b = mod, u = 1, v = 0, t;
		while (b > 0) {
			t = a / b;
			a -= t * b; std::swap(a, b);
			u -= t * v; std::swap(u, v);
		}
		return ModInt(u);
	}
	friend std::ostream & operator << (std::ostream &stream, const ModInt &p) { return stream << p.x; }
	friend std::istream & operator >> (std::istream &stream, ModInt &a) {
		int64_t x;
		stream >> x;
		a = ModInt(x);
		return stream;
	}
};
typedef ModInt<998244353> mint;

// Capacity of the node pool shared by every avl_map instantiation's instances.
constexpr int MAX_NODE = 200000 * 20;

// Ordered map (AVL tree) from key_t to value_t that additionally maintains
//   * a subtree aggregate `agg` of type agg_t (agg = left.agg + (agg_t)val + right.agg), and
//   * a lazily propagated operation `op` of type op_t applicable to every value in a key range.
// Requirements on the parameter types, as used below:
//   agg_t: value-initialisation gives the neutral element, operator+, constructible from value_t;
//   op_t : op_t() is the identity, operator!=, and `a * b` must mean "apply a, then b"
//          (flush() composes a child's older pending op with the newer parent op as old * new);
//          apply_val(value_t&) / apply_agg(agg_t&) apply the operation in place.
// Key ranges are half-open [key_l, key_r). KEY_MIN / KEY_MAX are used as "unbounded", so a key
// equal to KEY_MAX is never covered by a range. The range code does `key_r - 1`, i.e. it assumes
// an integral key_t.
// All instances share one static pool of MAX_NODE nodes allocated by the first constructor call.
// Nodes are never recycled (remove() only unlinks), and pool index 0 is the NONE sentinel that
// stands in for every empty subtree. Because of this shared mutable state, it is not thread-safe.
template<typename key_t, typename value_t, typename agg_t, typename op_t>
struct avl_map {
	static inline constexpr key_t KEY_MIN = std::numeric_limits<key_t>::min();
	static inline constexpr key_t KEY_MAX = std::numeric_limits<key_t>::max();

	struct Node {
		Node *l;
		Node *r;
		int size;
		int height;
		key_t key;
		value_t val;
		agg_t agg;
		op_t op;  // pending operation: already applied to nothing below this node's val/agg

		// Recomputes size/height/agg from the (flushed) children.
		Node *fetch() {
			if (l != avl_map::NONE) l->flush();
			if (r != avl_map::NONE) r->flush();
			size = 1 + l->size + r->size;
			height = 1 + std::max(l->height, r->height);
			agg = l->agg + static_cast<agg_t>(val) + r->agg;
			return this;
		}
		// Applies the pending operation to this node and pushes it to the children.
		Node *flush() {
			if (op != op_t()) {
				op.apply_val(val);
				op.apply_agg(agg);
				// Empty children are the shared NONE sentinel; never mutate it.
				if (l != avl_map::NONE) l->op = l->op * op;
				if (r != avl_map::NONE) r->op = r->op * op;
				op = op_t();
			}
			return this;
		}
		Node *rotate_l() {
			Node *new_root = r->flush();
			r = new_root->l;
			new_root->l = this;
			fetch();
			return new_root->fetch();
		}
		Node *rotate_r() {
			Node *new_root = l->flush();
			l = new_root->r;
			new_root->r = this;
			fetch();
			return new_root->fetch();
		}
		int height_diff() { return l->height - r->height; }
		// Restores the AVL invariant at this node (single or double rotation) and refreshes it.
		Node *balance() {
			int dif = height_diff();
			if (dif == 2) {
				if (l->flush()->height_diff() < 0) l = l->rotate_l();
				return rotate_r();
			} else if (dif == -2) {
				if (r->flush()->height_diff() > 0) r = r->rotate_r();
				return rotate_l();
			}
			return fetch();
		}
	};

	static inline Node *nodes = nullptr, *removed_tmp = nullptr, *NONE = nullptr;
	Node *root;
	static inline int head = 0;  // next free pool index

	avl_map (Node *root) : avl_map() { this->root = root; }
	avl_map () {
		if (!head) {
			NONE = nodes = new Node[MAX_NODE];
			nodes[head++] = {NONE, NONE, 0, 0, key_t(), value_t(), agg_t(), op_t()};
		}
		root = NONE;
	}

	// Inserts (key, val); if key already exists, func(node, val) updates it instead.
	template<typename T>
	static Node *insert(Node *node, const key_t &key, const value_t &val, const T &func) {
		if (node == NONE) {
			assert(head < MAX_NODE && "avl_map node pool exhausted");
			return &(nodes[head++] = {NONE, NONE, 1, 1, key, val, val, op_t()});
		}
		node->flush();
		if (key < node->key) node->l = insert(node->l, key, val, func);
		else if (key > node->key) node->r = insert(node->r, key, val, func);
		else func(node, val);
		return node->balance();
	}
	// Unlinks the maximum node of the subtree into removed_tmp; returns the new subtree root.
	static Node *remove_rightmost(Node *node) {
		node->flush();
		if (node->r != NONE) {
			node->r = remove_rightmost(node->r);
			return node->balance();
		}
		return (removed_tmp = node)->l;
	}
	static Node *remove(Node *node, const key_t &key) {
		if (node == NONE) return node;
		node->flush();
		if (key < node->key) node->l = remove(node->l, key);
		else if (key > node->key) node->r = remove(node->r, key);
		else {
			if (node->l == NONE) return node->r;
			// Replace the removed node by its in-order predecessor.
			node->l = remove_rightmost(node->l);
			removed_tmp->l = node->l;
			removed_tmp->r = node->r;
			return removed_tmp->balance();
		}
		return node->balance();
	}
	// Joins l, m, r where every key in l < m->key < every key in r.
	static Node *merge_with_root(Node *l, Node *m, Node *r) {
		int dif = l->height - r->height;
		if (-2 <= dif && dif <= 2) {
			m->l = l;
			m->r = r;
			return m->balance();
		} else if (dif > 0) {
			l->flush();
			l->r = merge_with_root(l->r, m, r);
			return l->balance();
		} else {
			r->flush();
			r->l = merge_with_root(l, m, r->l);
			return r->balance();
		}
	}
	// Concatenates two trees where every key in l < every key in r.
	static Node *merge(Node *l, Node *r) {
		if (l == NONE) return r;
		if (r == NONE) return l;
		l = remove_rightmost(l);
		return merge_with_root(l, removed_tmp, r);
	}
	// Splits into {keys < key, keys >= key}.
	static std::pair<Node *, Node *> split(Node *node, const key_t &key) {
		if (node == NONE) return {NONE, NONE};
		node->flush();
		Node *l = node->l;
		Node *r = node->r;
		node->l = node->r = NONE;
		if (key < node->key) {
			auto tmp = split(l, key);
			return {tmp.first, merge_with_root(tmp.second, node, r)};
		} else if (key > node->key) {
			auto tmp = split(r, key);
			return {merge_with_root(l, node, tmp.first), tmp.second};
		}
		return {l, merge_with_root(NONE, node->fetch(), r)};
	}
	// Sum of (agg_t)val over keys in [key_l, key_r) within this subtree.
	static agg_t aggregate(Node *node, const key_t &key_l, const key_t &key_r) {
		agg_t res{};
		if (node == NONE || key_r <= key_l) return res;
		node->flush();
		if (key_l == KEY_MIN && key_r == KEY_MAX) return node->agg;
		if (key_l < node->key) res = res + aggregate(node->l, key_l, key_r > node->key ? KEY_MAX : key_r);
		if (key_l <= node->key && key_r > node->key) res = res + static_cast<agg_t>(node->val);
		// Same as key_r > node->key + 1, but cannot overflow when node->key == KEY_MAX
		// (key_r > key_l >= KEY_MIN here, so key_r - 1 is representable).
		if (key_r - 1 > node->key) res = res + aggregate(node->r, key_l <= node->key ? KEY_MIN : key_l, key_r);
		return res;
	}
	// Applies op to every value whose key lies in [key_l, key_r) within this subtree.
	static void operate(Node *node, const key_t &key_l, const key_t &key_r, const op_t &op) {
		if (node == NONE || key_r <= key_l) return;
		node->flush();
		if (key_l == KEY_MIN && key_r == KEY_MAX) {
			// Whole subtree: apply here and leave the rest pending on the children.
			node->op = op;
			node->flush();
			return;
		}
		if (key_l < node->key) operate(node->l, key_l, key_r > node->key ? KEY_MAX : key_r, op);
		if (key_l <= node->key && key_r > node->key) op.apply_val(node->val);
		if (key_r - 1 > node->key) operate(node->r, key_l <= node->key ? KEY_MIN : key_l, key_r, op);
		node->fetch();
	}
	// Appends the (key, value) pairs of this subtree in increasing key order.
	static void dump(Node *node, std::vector<std::pair<key_t, value_t> > &res) {
		if (node == NONE) return;  // an empty tree contributes nothing (never touch the sentinel)
		node->flush();
		dump(node->l, res);
		res.push_back({node->key, node->val});
		dump(node->r, res);
	}

	// member functions
	int size() const { return root->size; }
	void assign(const key_t &key, const value_t &val) {
		root = insert(root, key, val, [] (Node *node, const value_t &v) { node->val = v; });
	}
	void add(const key_t &key, const value_t &val) {
		root = insert(root, key, val, [] (Node *node, const value_t &v) {
			node->val = static_cast<agg_t>(node->val) + static_cast<agg_t>(v);
		});
	}
	void remove(const key_t &key) { root = remove(root, key); }
	// Concatenates maps whose key ranges are increasing and disjoint; empty input gives an empty map.
	static avl_map merge(const std::vector<avl_map> &a) {
		if (a.empty()) return avl_map();
		Node *res = a[0].root;
		for (int i = 1; i < static_cast<int>(a.size()); i++) res = merge(res, a[i].root);
		return avl_map(res);
	}
	std::pair<avl_map, avl_map> split(const key_t &key) {
		auto tmp = split(root, key);
		return {avl_map(tmp.first), avl_map(tmp.second)};
	}
	agg_t aggregate(const key_t &l, const key_t &r) { return aggregate(root, l, r); }
	agg_t aggregate_left(const key_t &r) { return aggregate(root, KEY_MIN, r); }
	agg_t aggregate_right(const key_t &l) { return aggregate(root, l, KEY_MAX); }
	void operate(const key_t &l, const key_t &r, const op_t &op) { operate(root, l, r, op); }
	void operate_left(const key_t &r, const op_t &op) { operate(root, KEY_MIN, r, op); }
	void operate_right(const key_t &l, const op_t &op) { operate(root, l, KEY_MAX, op); }
	std::vector<std::pair<key_t, value_t> > dump() const {
		std::vector<std::pair<key_t, value_t> > res;
		dump(root, res);
		return res;
	}
	// Rewinds the shared pool; every existing map becomes invalid afterwards.
	static void reset() { if (head) head = 1; }
	void clear() { root = NONE; }
};

// Element-wise sum of two equally sized arrays.
template<typename T, std::size_t n>
std::array<T, n> array_add(const std::array<T, n> &a, const std::array<T, n> &b) {
	std::array<T, n> res;
	for (std::size_t i = 0; i < n; i++) res[i] = a[i] + b[i];
	return res;
}

// Map value: two pairs (vec[0], vec[1]) and (vec[2], vec[3]), added component-wise.
struct value_t {
	std::array<mint, 4> vec{};
	value_t operator + (const value_t &rhs) const {
		std::array<mint, 4> res;
		for (int i = 0; i < 4; i++) res[i] = vec[i] + rhs.vec[i];
		return {res};
	}
};

// Linear map on value_t: each pair (vec[0], vec[1]) and (vec[2], vec[3]) is treated as a
// row vector and multiplied on the right by the 2x2 matrix `mat` (identity by default).
// With mat = [[s0, s1], [0, s0]] this multiplies each pair like a dual number a + b*eps by s0 + s1*eps.
struct op_t {
	std::array<std::array<mint, 2>, 2> mat = {{{1, 0}, {0, 1}}};
	static op_t zero() { return op_t{{{{0, 0}, {0, 0}}}}; }
	template<typename T> void apply(T &val) const {
		std::array<mint, 4> tmp{};
		for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) tmp[j + 0] += val.vec[i + 0] * mat[i][j];
		for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) tmp[j + 2] += val.vec[i + 2] * mat[i][j];
		val.vec = tmp;
	}
	void apply_val(value_t &val) const { apply(val); }
	void apply_agg(value_t &sum) const { apply(sum); }
	bool operator != (const op_t &rhs) const { return mat != rhs.mat; }
	// Matrix product; with row-vector application, (A * B) means "apply A, then B".
	op_t operator * (const op_t &rhs) const {
		op_t res;
		for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++)
			res.mat[i][j] = mat[i][0] * rhs.mat[0][j] + mat[i][1] * rhs.mat[1][j];
		return res;
	}
};

using map = avl_map<int, value_t, value_t, op_t>;

std::vector<std::vector<int> > hen;  // adjacency lists of the tree

// Merges two per-subtree maps small-to-large. The larger map ends up in lhs
// (the two are swapped first if needed) and receives the result. rhs keeps the
// smaller input's nodes and is not meant to be used afterwards.
// After the swap, and writing a pair as a dual number a + b*eps, the new value at key k is
//   lhs[k] * (sum of rhs over keys <= k)   (applied to both pairs of lhs[k]), plus
//   rhs[k] * (sum of lhs over keys <  k)   (first pair only; the second pair of this
//                                            contribution is k times the first pair).
void merge(map &lhs, map &rhs) {
	if (lhs.size() < rhs.size()) std::swap(lhs, rhs);
	auto entries = rhs.dump();

	// Contributions of rhs entries, computed against lhs before lhs is modified.
	std::vector<std::pair<int, std::array<mint, 4> > > add_list;
	add_list.reserve(entries.size());
	for (int i = static_cast<int>(entries.size()); i--; ) {
		int key = entries[i].first;
		const auto &val = entries[i].second.vec;
		auto cur_sum = lhs.aggregate_left(key).vec;
		mint r0 = val[0] * cur_sum[0];
		mint r1 = val[0] * cur_sum[1] + val[1] * cur_sum[0];
		mint r2 = r0 * key;
		mint r3 = r1 * key;
		add_list.push_back({key, {r0, r1, r2, r3}});
	}

	// Scale each lhs key range between consecutive rhs keys by the running rhs prefix sum.
	mint sum[4] = { 0 };
	int last_key = map::KEY_MIN;
	for (const auto &entry : entries) {
		int key = entry.first;
		lhs.operate(last_key, key, op_t{{{{sum[0], sum[1]}, {0, sum[0]}}}});
		last_key = key;
		for (int j = 0; j < 4; j++) sum[j] += entry.second.vec[j];
	}
	lhs.operate(last_key, map::KEY_MAX, op_t{{{{sum[0], sum[1]}, {0, sum[0]}}}});

	for (const auto &item : add_list) lhs.add(item.first, value_t{item.second});
}

std::vector<int> cnt;       // cnt[i]: number of vertices in the subtree of i
std::vector<map> dp;        // dp[i]: per-subtree map keyed by vertex value
std::vector<int> a;         // vertex values
std::vector<mint> power2;   // power2[k] = 2^k mod 998244353
int n;
mint res = 0;

// Post-order DFS from vertex i (parent prev, -1 for the root). Builds dp[i] and cnt[i],
// and adds dp[i]'s total vec[3] times 2^(n - cnt[i] - [i is not the root]) to res.
// NOTE(review): the combinatorial meaning of the 4 components depends on the problem
// statement, which is not in this file; key -1 presumably sorts below every a[i]
// (i.e. a[i] >= 0) — TODO confirm against the constraints.
// Recursion depth equals the tree depth (up to n).
void dfs(int i, int prev) {
	dp[i].assign(a[i], value_t{{1, 1, a[i], a[i]}});
	for (int j : hen[i]) {
		if (j == prev) continue;
		dfs(j, i);
		cnt[i] += cnt[j];
		dp[j].assign(-1, value_t{{power2[cnt[j] - 1], 0, 0, -power2[cnt[j] - 1]}});
		merge(dp[i], dp[j]);
	}
	res += dp[i].aggregate(map::KEY_MIN, map::KEY_MAX).vec[3] * power2[n - cnt[i] - (prev != -1)];
}

// Input: n, then a[0..n-1], then n-1 edges as 1-indexed vertex pairs.
// Output: res modulo 998244353.
int main() {
	n = ri();
	a.resize(n);
	for (auto &i : a) i = ri();
	hen.resize(n);
	for (int i = 1; i < n; i++) {
		int x = ri() - 1;
		int y = ri() - 1;
		hen[x].push_back(y);
		hen[y].push_back(x);
	}
	dp.resize(n);
	cnt.resize(n, 1);
	power2.resize(n + 1, 1);
	for (int i = 1; i <= n; i++) power2[i] = power2[i - 1] + power2[i - 1];
	res = 0;
	dfs(0, -1);
	std::cout << res << '\n';
	return 0;
}