#include #include using namespace std; using namespace atcoder; #define rep(i,m,n,k) for (int i = (int)(m); i < (int)(n); i += (int)(k)) #define rrep(i,m,n,k) for (int i = (int)(m); i > (int)(n); i += (int)(k)) #define ll long long #define list(T,A,N) vector A(N);for(int i=0;i<(int)(N);i++){cin >> A[i];} using mint = modint998244353; tuple,vector, vector> sub_par_dist(vector> e, long long root){ long long N = e.size(); vector par(N,-1); vector sub(N,-1); vector dist(N,-1); queue v; dist[root] = 0; v.push(root); long long x; while (!v.empty()){ x = v.front();v.pop(); for (auto ix:e[x]){ if (dist[ix]!=-1) continue; dist[ix] = dist[x] + 1; v.push(ix); } } vector> H; for (long long i=0;i> N; vector> e(N); ll u,v; rep(_,0,N-1,1){ cin >> u >> v; u -= 1; v -= 1; e[u].push_back(v); e[v].push_back(u); } list(ll,A,N); auto [sub,par,dist] = sub_par_dist(e,0); mint ans = 0; mint sx,sx2,sy,sy2; rep(i,0,N,1){ sx = mint(0); sy = mint(0); sx2 = mint(0); sy2 = mint(0); for(ll ix:e[i]){ if(par[i]==ix){ if(A[ix]>A[i]){ sx += mint(N-sub[i]); sx2 += mint((N-sub[i])*(N-sub[i])); } else if(A[ix]A[i]){ sx += mint(sub[ix]); sx2 += mint(sub[ix]*sub[ix]); } else if(A[ix]