#include using namespace std; #include using mint=atcoder::modint998244353; int main(){ int n;cin>>n; vector a(n);for(auto&e:a){int x;cin>>x;e=x;} vector> g(n); for(int i=0;i>u>>v; u--;v--; g[u].push_back(v); g[v].push_back(u); } mint ans=0; auto dfs=[&](auto dfs,int p,int prev=-1)->mint { mint sm1=0,sm2=0; for(auto&e:g[p]){ if(e==prev)continue; mint esm=dfs(dfs,e,p); sm1+=esm; sm2+=esm*esm; } ans+=a[p]*(sm1+(sm1*sm1-sm2)/2); return a[p]*(sm1+1); }; dfs(dfs,0); cout<