"""Count (path, interior vertex) pairs where the vertex is a strict local extremum.

Input format (stdin, 1-based vertex ids):
    N
    u_1 v_1
    ...
    u_{N-1} v_{N-1}
    A_1 ... A_N

Output: the sum, over every vertex v, of the number of unordered vertex pairs
(x, y) whose tree path passes through v as an interior vertex such that both
path-neighbours of v have values strictly smaller than A[v], or both strictly
larger. The result is reported modulo 998244353.
"""
import sys

MOD = 998244353


def _count_cross_pairs(sizes):
    """Return sum of s_i * s_j over all i < j for the given component sizes."""
    remaining = sum(sizes)
    pairs = 0
    for s in sizes:
        remaining -= s
        pairs += s * remaining
    return pairs


def count_extremum_paths(n, edges, values, mod=MOD):
    """Return the local-extremum path count described in the module docstring.

    Args:
        n: number of vertices (vertices are 0-based: 0 .. n-1).
        edges: iterable of n-1 pairs (u, v) of 0-based vertex ids forming a tree.
        values: sequence with at least n values; values[i] is the weight of vertex i.
        mod: modulus applied to the running total (defaults to 998244353).

    Returns:
        The total count reduced modulo ``mod``.
    """
    if n <= 1:
        # A single vertex has no paths with an interior vertex.
        return 0

    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Iterative DFS from the root: records a pre-order (parents before children)
    # and each vertex's parent. The root's parent is the sentinel n, which is
    # never a real neighbour.
    root = 0
    parent = [-1] * n
    parent[root] = n
    order = []
    stack = [root]
    while stack:
        x = stack.pop()
        order.append(x)
        for y in adj[x]:
            if parent[y] == -1:
                parent[y] = x
                stack.append(y)

    # subtree_size[x] = number of vertices in x's subtree (including x).
    subtree_size = [1] * n
    for x in reversed(order):
        if x != root:
            subtree_size[parent[x]] += subtree_size[x]

    total = 0
    for v in range(n):
        a = values[v]
        lower = []   # sizes of components entered through a smaller-valued neighbour
        higher = []  # sizes of components entered through a larger-valued neighbour
        for y in adj[v]:
            # Removing v splits the tree; the component containing neighbour y
            # is y's subtree for a child, or everything outside v's subtree for
            # the parent.
            comp = n - subtree_size[v] if y == parent[v] else subtree_size[y]
            if values[y] < a:
                lower.append(comp)
            elif values[y] > a:
                higher.append(comp)
            # Equal-valued neighbours make v neither a strict max nor min.
        # A path through v is counted when both endpoints lie in two different
        # components that are entered through same-direction neighbours.
        total = (total + _count_cross_pairs(lower) + _count_cross_pairs(higher)) % mod
    return total


def main():
    """Read the tree from stdin and print the answer."""
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    pos = 1
    edges = []
    for _ in range(n - 1):
        edges.append((int(data[pos]) - 1, int(data[pos + 1]) - 1))
        pos += 2
    values = [int(t) for t in data[pos:pos + n]]
    print(count_extremum_paths(n, edges, values))


if __name__ == "__main__":
    main()