#if defined(LOCAL) #include #else #include #endif #include #pragma GCC optimize("Ofast") //#pragma GCC target("avx2") #pragma GCC optimize("unroll-loops") using namespace std; //#include //#include //namespace mp=boost::multiprecision; //#define mulint mp::cpp_int //#define mulfloat mp::cpp_dec_float_100 struct __INIT{__INIT(){cin.tie(0);ios::sync_with_stdio(false);cout<=0;(i)--) #define flc(x) __builtin_popcountll(x) #define pint pair #define pdouble pair #define plint pair #define fi first #define se second #define all(x) x.begin(),x.end() //#define vec vector #define nep(x) next_permutation(all(x)) typedef long long lint; int dx[8]={1,1,0,-1,-1,-1,0,1}; int dy[8]={0,1,1,1,0,-1,-1,-1}; const int MAX_N=3e5+5; templatebool chmax(T &a,const T &b){if(abool chmin(T &a,const T &b){if(b bucket[MAX_N/1000]; //constexpr int MOD=1000000007; constexpr int MOD=998244353; #include using namespace atcoder; typedef __int128_t llint; //https://ei1333.github.io/luzhiled/snippets/structure/red-black-tree.html template< class T > struct ArrayPool { vector< T > pool; vector< T * > stock; int ptr; ArrayPool(int sz) : pool(sz), stock(sz) {} inline T *alloc() { return stock[--ptr]; } inline void free(T *t) { stock[ptr++] = t; } void clear() { ptr = (int) pool.size(); for(int i = 0; i < pool.size(); i++) stock[i] = &pool[i]; } }; template< class D, class L, D (*f)(D, D), D (*g)(D, L), L (*h)(L, L), L (*p)(L, int) > struct RedBlackTree { enum COLOR { BLACK, RED }; struct Node { Node *l, *r; COLOR color; int level, cnt; D key, sum; L lazy; Node() {} Node(const D &k, const L &laz) : key(k), sum(k), l(nullptr), r(nullptr), color(BLACK), level(0), cnt(1), lazy(laz) {} Node(Node *l, Node *r, const D &k, const L &laz) : key(k), color(RED), l(l), r(r), lazy(laz) {} }; ArrayPool< Node > pool; const D M1; const L OM0; RedBlackTree(int sz, const D &M1, const L &OM0) : pool(sz), M1(M1), OM0(OM0) { pool.clear(); } inline Node *alloc(const D &key) { return &(*pool.alloc() = Node(key, OM0)); } inline Node *alloc(Node *l, Node *r) { auto t = &(*pool.alloc() = Node(l, r, M1, OM0)); return update(t); } virtual Node *clone(Node *t) { return t; } inline int count(const Node *t) { return t ? t->cnt : 0; } inline D sum(const Node *t) { return t ? t->sum : M1; } Node *update(Node *t) { t->cnt = count(t->l) + count(t->r) + (!t->l || !t->r); t->level = t->l ? t->l->level + (t->l->color == BLACK) : 0; t->sum = f(f(sum(t->l), t->key), sum(t->r)); return t; } Node *propagate(Node *t) { t = clone(t); if(t->lazy != OM0) { if(!t->l) { t->key = g(t->key, p(t->lazy, 1)); } else { if(t->l) { t->l = clone(t->l); t->l->lazy = h(t->l->lazy, t->lazy); t->l->sum = g(t->l->sum, p(t->lazy, count(t->l))); } if(t->r) { t->r = clone(t->r); t->r->lazy = h(t->r->lazy, t->lazy); t->r->sum = g(t->r->sum, p(t->lazy, count(t->r))); } } t->lazy = OM0; } return update(t); } Node *rotate(Node *t, bool b) { t = propagate(t); Node *s; if(b) { s = propagate(t->l); t->l = s->r; s->r = t; } else { s = propagate(t->r); t->r = s->l; s->l = t; } update(t); return update(s); } Node *submerge(Node *l, Node *r) { if(l->level < r->level) { r = propagate(r); Node *c = (r->l = submerge(l, r->l)); if(r->color == BLACK && c->color == RED && c->l && c->l->color == RED) { r->color = RED; c->color = BLACK; if(r->r->color == BLACK) return rotate(r, true); r->r->color = BLACK; } return update(r); } if(l->level > r->level) { l = propagate(l); Node *c = (l->r = submerge(l->r, r)); if(l->color == BLACK && c->color == RED && c->r && c->r->color == RED) { l->color = RED; c->color = BLACK; if(l->l->color == BLACK) return rotate(l, false); l->l->color = BLACK; } return update(l); } return alloc(l, r); } Node *merge(Node *l, Node *r) { if(!l || !r) return l ? l : r; Node *c = submerge(l, r); c->color = BLACK; return c; } pair< Node *, Node * > split(Node *t, int k) { if(!t) return {nullptr, nullptr}; t = propagate(t); if(k == 0) return {nullptr, t}; if(k >= count(t)) return {t, nullptr}; Node *l = t->l, *r = t->r; pool.free(t); if(k < count(l)) { auto pp = split(l, k); return {pp.first, merge(pp.second, r)}; } if(k > count(l)) { auto pp = split(r, k - count(l)); return {merge(l, pp.first), pp.second}; } return {l, r}; } Node *build(int l, int r, const vector< D > &v) { if(l + 1 >= r) return alloc(v[l]); return merge(build(l, (l + r) >> 1, v), build((l + r) >> 1, r, v)); } Node *build(const vector< D > &v) { //pool.clear(); return build(0, (int) v.size(), v); } void dump(Node *r, typename vector< D >::iterator &it, L lazy) { if(r->lazy != OM0) lazy = h(lazy, r->lazy); if(!r->l || !r->r) { *it++ = g(r->key, lazy); return; } dump(r->l, it, lazy); dump(r->r, it, lazy); } vector< D > dump(Node *r) { vector< D > v((size_t) count(r)); auto it = begin(v); dump(r, it, OM0); return v; } string to_string(Node *r) { auto s = dump(r); string ret; for(int i = 0; i < s.size(); i++) { ret += std::to_string(s[i]); ret += ", "; } return (ret); } void insert(Node *&t, int k, const D &v) { auto x = split(t, k); t = merge(merge(x.first, alloc(v)), x.second); } D erase(Node *&t, int k) { auto x = split(t, k); auto y = split(x.second, 1); auto v = y.first->key; pool.free(y.first); t = merge(x.first, y.second); return v; } D query(Node *&t, int a, int b) { auto x = split(t, a); auto y = split(x.second, b - a); auto ret = sum(y.first); t = merge(x.first, merge(y.first, y.second)); return ret; } void set_propagate(Node *&t, int a, int b, const L &pp) { auto x = split(t, a); auto y = split(x.second, b - a); y.first->lazy = h(y.first->lazy, pp); t = merge(x.first, merge(propagate(y.first), y.second)); } void set_element(Node *&t, int k, const D &x) { if(!t->l) { t->key = t->sum = x; return; } t = propagate(t); if(k < count(t->l)) set_element(t->l, k, x); else set_element(t->r, k - count(t->l), x); t = update(t); } int size(Node *t) { return count(t); } bool empty(Node *t) { return !t; } Node *makeset() { return (nullptr); } }; using mint=modint998244353; inline vector dd(vector a,vector b){ vector ret(5); rep(i,5) ret[i]=a[i]+b[i]; return ret; } inline vector dl(vector a,vector b){ vector ret(5); ret[0]=a[0]*b[0]+a[1]*b[1]; ret[1]=a[0]*b[2]+a[1]*b[3]; ret[2]=a[2]*b[0]+a[3]*b[1]; ret[3]=a[2]*b[2]+a[3]*b[3]; ret[4]=a[4]; return ret; //D } inline vector ll(vector a,vector b){ vector ret(4); ret[0]=a[0]*b[0]+a[1]*b[2]; ret[1]=a[0]*b[1]+a[1]*b[3]; ret[2]=a[2]*b[0]+a[3]*b[2]; ret[3]=a[2]*b[1]+a[3]*b[3]; return ret; } inline vector none(vector a,int b){ return a; } template class lazy_binary_trie { struct node { int cnt; U lazy; node *ch[2]; node() : cnt(0), lazy(0), ch{ nullptr, nullptr } {} }; void push(node* t, int b) { if ((t->lazy >> (U)b) & (U)1) swap(t->ch[0], t->ch[1]); if (t->ch[0]) t->ch[0]->lazy ^= t->lazy; if (t->ch[1]) t->ch[1]->lazy ^= t->lazy; t->lazy = 0; } node* add(node* t, U val, int b = B - 1) { if (!t) t = new node; t->cnt += 1; if (b < 0) return t; push(t, b); bool f = (val >> (U)b) & (U)1; t->ch[f] = add(t->ch[f], val, b - 1); return t; } node* sub(node* t, U val, int b = B - 1) { assert(t); t->cnt -= 1; if (t->cnt == 0) return nullptr; if (b < 0) return t; push(t, b); bool f = (val >> (U)b) & (U)1; t->ch[f] = sub(t->ch[f], val, b - 1); return t; } U get_min(node* t, U val, int b = B - 1) { assert(t); if (b < 0) return 0; push(t, b); bool f = (val >> (U)b) & (U)1; f ^= !t->ch[f]; return get_min(t->ch[f], val, b - 1) | ((U)f << (U)b); } U get(node* t, int k, int b = B - 1) { if (b < 0) return 0; push(t, b); int m = t->ch[0] ? t->ch[0]->cnt : 0; return k < m ? get(t->ch[0], k, b - 1) : get(t->ch[1], k - m, b - 1) | ((U)1 << (U)b); } int count_lower(node* t, U val, int b = B - 1) { if (!t || b < 0) return 0; push(t, b); bool f = (val >> (U)b) & (U)1; return (f && t->ch[0] ? t->ch[0]->cnt : 0) + count_lower(t->ch[f], val, b - 1); } node *root; public: lazy_binary_trie() : root(nullptr) {} int size() const { return root ? root->cnt : 0; } bool empty() const { return !root; } void insert(U val) { root = add(root, val); } void erase(U val) { root = sub(root, val); } void xor_all(U val) { if (root) root->lazy ^= val; } U max_element(U bias = 0) { return get_min(root, ~bias); } U min_element(U bias = 0) { return get_min(root, bias); } int lower_bound(U val) { // return id return count_lower(root, val); } int upper_bound(U val) { // return id return count_lower(root, val + 1); } U operator[](int k) { assert(0 <= k && k < size()); return get(root, k); } int count(U val) { if (!root) return 0; node *t = root; for (int i = B - 1; i >= 0; i--) { push(t, i); t = t->ch[(val >> (U)i) & (U)1]; if (!t) return 0; } return t->cnt; } }; using T = RedBlackTree,vector,dd,dl,ll,none>; T RBT(2000000,vector{0,0,0,0,0},vector{1,0,0,1}); vector t; mint ans=0; int N; vector edge[100005]; int sub_size[100005]; mint pow2[100005]; lint A[100005]; int pt[100005]; vector> bt(100000); vector*> bt_pt(100000); bool merge(int left,int right){ int rev=false; if(RBT.size(t[left])> adds; int N=RBT.size(t[right]); rep(i,N){ //key未満 vector now=RBT.query(t[right],i,i+1); int key=now[4].val(); int hi=bt_pt[left]->lower_bound(key); vector lo_sum=RBT.query(t[left],0,hi); mint dp1=now[0]*lo_sum[0]; mint dp2=now[0]*lo_sum[1]+now[1]*lo_sum[0]; mint dp1_key=dp1*key; mint dp2_key=dp2*key; adds.push_back({dp1,dp2,dp1_key,dp2_key,key}); } int before=0; mint sum[2]; rep(i,N){ vector now=RBT.query(t[right],i,i+1); int key=now[4].val(); int hi=bt_pt[left]->lower_bound(key); if(before!=hi) RBT.set_propagate(t[left],before,hi,{sum[0],0,sum[1],sum[0]}); sum[0]+=now[0],sum[1]+=now[1]; before=hi; } if(before!=RBT.size(t[left])) RBT.set_propagate(t[left],before,RBT.size(t[left]),{sum[0],0,sum[1],sum[0]}); for(auto e:adds){ int key=e[4].val(); int hi=bt_pt[left]->lower_bound(key); if(bt_pt[left]->count(e[4].val())){ vector res=RBT.query(t[left],hi,hi+1); rep(i,4) res[i]+=e[i]; RBT.set_element(t[left],hi,res); } else{ RBT.insert(t[left],hi,e); bt_pt[left]->insert(e[4].val()); } } return rev; } void dfs(int now,int par){ for(auto e:edge[now]){ if(e==par) continue; dfs(e,now); sub_size[now]+=sub_size[e]; if(bt_pt[e]->count(0)){ vector zero=RBT.query(t[e],0,1); zero[0]+=pow2[sub_size[e]-1]; RBT.set_element(t[e],0,zero); } else{ RBT.insert(t[e],0,{pow2[sub_size[e]-1],0,0,0,0}); bt_pt[e]->insert(0); } merge(now,e); } ans+=RBT.query(t[pt[now]],0,RBT.size(t[pt[now]]))[3]*pow2[N-1-sub_size[now]+(now==0)]; } int main(void){ cin >> N; rep(i,N) cin >> A[i]; rep(i,N) sub_size[i]=1; pow2[0]=1; rep(i,100004) pow2[i+1]=pow2[i]*2; rep(i,100004) pt[i]=i; rep(i,N) bt_pt[i]=&bt[i]; rep(i,N){ T::Node* tt=RBT.makeset(); RBT.insert(tt,0,{1,1,A[i],A[i],A[i]}); t.push_back(tt); bt_pt[i]->insert(A[i]); } rep(i,N-1){ int u,v; cin >> u >> v; u--,v--; edge[u].push_back(v); edge[v].push_back(u); } dfs(0,-1); int sum1=0,sum2=0; rep(i,N){ sum1+=RBT.size(t[i]); sum2+=bt_pt[i]->size(); } cout << ans.val() << endl; }