"""Tree path-product sum modulo 998244353.

Reads a tree with N weighted vertices from stdin. Prints the sum, over every
simple path containing at least two vertices, of the product of the vertex
weights along the path, reduced modulo 998244353.

Input format (whitespace-separated tokens):
    N
    A_1 ... A_N
    u_1 v_1        (N-1 edges, 1-indexed)
    ...
"""
import sys
from collections import deque

MOD = 998244353


def solve(n, a, edges):
    """Return the sum of weight-products over all paths with >= 2 vertices.

    Args:
        n: number of vertices (0-indexed 0..n-1).
        a: vertex weights; a[v] is the weight of vertex v.
        edges: iterable of (u, v) pairs, 0-indexed, describing a tree.

    Returns:
        The sum modulo MOD. Each unordered path is counted once, at its
        topmost vertex when the tree is rooted at vertex 0.
    """
    if n == 0:
        return 0

    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # BFS from vertex 0: `order` lists every vertex after its parent, so
    # walking it in reverse visits all children before their parent.
    parent = [-1] * n
    seen = [False] * n
    seen[0] = True
    order = []
    queue = deque([0])
    while queue:
        v = queue.popleft()
        order.append(v)
        for w in adj[v]:
            if not seen[w]:
                seen[w] = True
                parent[w] = v
                queue.append(w)

    inv2 = pow(2, -1, MOD)  # modular inverse of 2; loop-invariant
    # down[v] = sum over downward paths starting at v (including the single
    # vertex v itself) of the product of weights on the path.
    down = [0] * n
    total = 0
    for v in reversed(order):
        child_sum = 0
        child_sq = 0
        for w in adj[v]:
            if w == parent[v]:
                continue  # only children contribute to v's subtree
            x = down[w]
            child_sum += x
            child_sq += x * x
        child_sum %= MOD
        child_sq %= MOD
        wv = a[v] % MOD
        down[v] = wv * (1 + child_sum) % MOD
        # Paths bending at v: one downward path into each of two distinct
        # children, i.e. wv * sum_{c1<c2} down[c1]*down[c2]
        # = wv * (S^2 - sum of squares) / 2.
        bend = wv * ((child_sum * child_sum - child_sq) % MOD) % MOD * inv2 % MOD
        # down[v] - wv: downward paths from v with at least one edge.
        total += down[v] - wv + bend
    return total % MOD


def main():
    """Parse the tree from stdin and print the answer."""
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    a = [int(tok) for tok in data[1:1 + n]]
    rest = data[1 + n:]
    # Convert the 1-indexed input edges to 0-indexed pairs.
    edges = [
        (int(rest[2 * i]) - 1, int(rest[2 * i + 1]) - 1) for i in range(n - 1)
    ]
    print(solve(n, a, edges))


if __name__ == "__main__":
    main()