"""Count interior strict-extremum occurrences over all paths of a tree.

Input (stdin): N, then N-1 edges "U V" (1-indexed), then N integers A.
Output: summed over every vertex i, the number of unordered vertex pairs
whose connecting path passes through i as an interior vertex with both path
neighbours of i strictly greater than A[i], or both strictly smaller,
reduced modulo MOD.
"""
import sys
from collections import deque

MOD = 998244353


def pairs(arr):
    """Return the sum of arr[p] * arr[q] over all index pairs p < q.

    Uses (sum x)^2 = sum x^2 + 2 * sum_{p<q} x_p * x_q; for integer input the
    numerator is always even, so the floor division is exact. Returns 0 for
    an empty or single-element input.
    """
    total = 0
    squares = 0
    for x in arr:
        total += x
        squares += x * x
    return (total * total - squares) // 2


def solve(n, edges, a):
    """Return the answer (mod MOD) for a tree on vertices 0..n-1.

    Args:
        n: number of vertices.
        edges: iterable of (u, v) pairs, 0-indexed; assumed to form a tree.
        a: vertex values; a[i] is the value of vertex i.

    Removing vertex i splits the tree into one component per neighbour j.
    Any two vertices taken from two different components are joined by a
    path through i whose neighbours of i on that path are those two j's.
    So for each i the component sizes are grouped by whether the attaching
    neighbour's value is strictly greater or strictly smaller than a[i]
    (equal values contribute nothing), and the pairwise products within
    each group are added.
    """
    if n <= 0:
        return 0

    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # BFS from vertex 0: record each vertex's parent and an order in which
    # every parent precedes its children.
    parent = [-1] * n
    seen = [False] * n
    seen[0] = True
    order = []
    queue = deque([0])
    while queue:
        v = queue.popleft()
        order.append(v)
        for w in adj[v]:
            if not seen[w]:
                seen[w] = True
                parent[w] = v
                queue.append(w)

    # Subtree sizes, accumulated children-first (reverse BFS order).
    size = [1] * n
    for v in reversed(order):
        p = parent[v]
        if p != -1:
            size[p] += size[v]

    ans = 0
    for i in range(n):
        higher = []  # component sizes behind neighbours with a[j] > a[i]
        lower = []   # component sizes behind neighbours with a[j] < a[i]
        ai = a[i]
        for j in adj[i]:
            # The parent's side of i holds every vertex outside i's subtree.
            comp = n - size[i] if j == parent[i] else size[j]
            if a[j] > ai:
                higher.append(comp)
            elif a[j] < ai:
                lower.append(comp)
        ans = (ans + pairs(higher) + pairs(lower)) % MOD
    return ans


def main():
    """Read the tree from stdin and print the answer."""
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    edges = []
    pos = 1
    for _ in range(n - 1):
        edges.append((int(data[pos]) - 1, int(data[pos + 1]) - 1))
        pos += 2
    a = [int(x) for x in data[pos:pos + n]]
    print(solve(n, edges, a))


if __name__ == "__main__":
    main()