#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")

// Reads a tree (N vertices, values A[0..N-1], N-1 one-indexed edges) and prints,
// modulo 998244353, the sum over every subset of edges to cut of
//   sum over resulting components C of |C| * max_{v in C} A[v].
// (Derived from the DP below — TODO(review): confirm against the original problem statement.)

constexpr int MOD = 998244353;

// Minimal modular integer over the prime MOD, covering only the operations this
// program needs (self-contained stand-in for atcoder::modint998244353).
struct mint {
    static constexpr std::uint32_t kMod = MOD;
    std::uint32_t v = 0;  // invariant: 0 <= v < kMod

    mint() = default;
    // Implicit on purpose so `mint ans = 0` and `mint * int` work.
    mint(long long x) {
        x %= MOD;
        if (x < 0) x += MOD;
        v = static_cast<std::uint32_t>(x);
    }
    std::uint32_t val() const { return v; }

    mint& operator+=(mint o) {
        v += o.v;  // both < kMod < 2^30, so no overflow
        if (v >= kMod) v -= kMod;
        return *this;
    }
    mint& operator-=(mint o) {
        v += kMod - o.v;
        if (v >= kMod) v -= kMod;
        return *this;
    }
    mint& operator*=(mint o) {
        v = static_cast<std::uint32_t>(static_cast<std::uint64_t>(v) * o.v % kMod);
        return *this;
    }
    friend mint operator+(mint a, mint b) { return a += b; }
    friend mint operator-(mint a, mint b) { return a -= b; }
    friend mint operator*(mint a, mint b) { return a *= b; }
};

namespace {

constexpr int kMaxN = 5005;
// Vertex values index the DP directly, so they must lie in [0, kMaxA]
// (presumably the problem's constraint — verify).
constexpr int kMaxA = 1000;
constexpr int kWidth = kMaxA + 1;

int N;
long long A[kMaxN];
std::vector<int> edge[kMaxN];
mint dp1[kMaxN][kWidth];  // dp1[v][j]: # cut sets inside subtree(v) whose component containing v has max j
mint dp2[kMaxN][kWidth];  // dp2[v][j]: sum, over those cut sets, of that component's size
int dp3[kMaxN];           // dp3[v]: number of vertices in subtree(v)
mint pow2[kMaxN];         // pow2[k] = 2^k

// Prefix-sum scratch for merge_child. Kept in static storage rather than on the
// stack so deep trees cannot exhaust the stack.
mint pre_p1[kWidth + 1], pre_p2[kWidth + 1], pre_c1[kWidth + 1], pre_c2[kWidth + 1];

// Folds the finished subtree of child c into the DP of its parent p.
void merge_child(int p, int c) {
    // Cutting edge p-c acts like the child contributing size 0 and max 0, and it
    // leaves the dp3[c]-1 edges inside subtree(c) free. This is added (not
    // assigned) so configurations where c's component really has max 0 survive.
    dp1[c][0] += pow2[dp3[c] - 1];

    // Cumulative counts: pre_x[j+1] = sum_{i<=j} dp_x[i]; pre_x[0] stays 0.
    for (int j = 0; j < kWidth; ++j) {
        pre_p1[j + 1] = pre_p1[j] + dp1[p][j];
        pre_p2[j + 1] = pre_p2[j] + dp2[p][j];
        pre_c1[j + 1] = pre_c1[j] + dp1[c][j];
        pre_c2[j + 1] = pre_c2[j] + dp2[c][j];
    }
    // "max <= j" for the merged component holds iff it holds on both sides, so
    // the cumulative counts multiply; sizes add, giving the product rule for dp2.
    for (int j = 0; j < kWidth; ++j) {
        dp1[p][j] = pre_p1[j + 1] * pre_c1[j + 1];
        dp2[p][j] = pre_p1[j + 1] * pre_c2[j + 1] + pre_c1[j + 1] * pre_p2[j + 1];
    }
    // Convert cumulative ("max <= j") back to exact ("max == j"), high to low so
    // each step still reads an untouched cumulative value.
    for (int j = kWidth - 1; j > 0; --j) {
        dp1[p][j] -= dp1[p][j - 1];
        dp2[p][j] -= dp2[p][j - 1];
    }
    dp3[p] += dp3[c];
}

// Runs the tree DP rooted at vertex 0 over the global input and returns the answer.
mint solve() {
    if (N <= 0) return 0;

    pow2[0] = 1;
    for (int k = 0; k + 1 < kMaxN; ++k) pow2[k + 1] = pow2[k] * 2;

    for (int v = 0; v < N; ++v) {
        assert(0 <= A[v] && A[v] <= kMaxA);
        dp1[v][A[v]] = 1;
        dp2[v][A[v]] = 1;
        dp3[v] = 1;
    }

    // BFS from root 0: every child appears after its parent in `order`, so
    // walking it backwards merges each subtree only once it is complete.
    // Iterative on purpose: no recursion depth proportional to tree height.
    std::vector<int> order{0};
    std::vector<int> parent(N, -1);
    order.reserve(N);
    for (int k = 0; k < static_cast<int>(order.size()); ++k) {
        const int v = order[k];
        for (const int u : edge[v]) {
            if (u == parent[v]) continue;
            parent[u] = v;
            order.push_back(u);
        }
    }
    for (int k = static_cast<int>(order.size()) - 1; k > 0; --k) {
        merge_child(parent[order[k]], order[k]);
    }

    mint ans = 0;
    for (int v = 0; v < N; ++v) {
        // Count components whose topmost vertex is v. Unless v is the root, the
        // edge to its parent is cut and the other N-1-dp3[v] edges outside
        // subtree(v) are free.
        const int free_edges = (v == 0) ? 0 : N - 1 - dp3[v];
        mint weighted = 0;
        for (int j = 0; j < kWidth; ++j) weighted += dp2[v][j] * j;
        ans += weighted * pow2[free_edges];
    }
    return ans;
}

}  // namespace

int main() {
    std::cin.tie(nullptr);
    std::ios::sync_with_stdio(false);

    std::cin >> N;
    for (int i = 0; i < N; ++i) std::cin >> A[i];
    for (int i = 0; i + 1 < N; ++i) {
        int u, v;
        std::cin >> u >> v;
        --u;  // input vertices are 1-indexed
        --v;
        edge[u].push_back(v);
        edge[v].push_back(u);
    }
    std::cout << solve().val() << '\n';
    return 0;
}