N=int(input()) A=list(map(int,input().split())) G=[[] for i in range(N)] for i in range(N-1): a,b=map(int,input().split()) G[a-1].append(b-1) G[b-1].append(a-1) dist=[-1]*N dist[0]=0 from collections import deque S=deque() S.append(0) while S: x=S.pop() for y in G[x]: if dist[y]>=0: continue dist[y]=dist[x]+1 S.append(y) L=[] for i in range(N): L.append((dist[i],i)) mod=998244353 dp=[0]*N result=0 L.sort(reverse=True) k=pow(2,-1,mod) for i in range(N): x=L[i][1] count=0 p=A[x] for y in G[x]: if dist[y]