import sys


class RangeAssignSum:
    """Integer array supporting range assignment and range sum.

    Uses sqrt decomposition: the array is split into blocks of ``b`` cells.
    Each block keeps its sum, plus an optional pending whole-block
    assignment. A pending assignment is materialised into the cells only
    when a partial-block operation needs them. Both operations cost
    O(b + len / b).
    """

    def __init__(self, values):
        self.vals = list(values)
        size = len(self.vals)
        # Block size grows with the input; a fixed size would cap n.
        self.b = b = int(size ** 0.5) + 1
        nblocks = (size + b - 1) // b
        self.sums = [sum(self.vals[i * b:(i + 1) * b]) for i in range(nblocks)]
        # None means "no pending assignment". Unlike a numeric sentinel,
        # it cannot collide with a legitimately assigned value such as -1.
        self.lazy = [None] * nblocks

    def _push(self, blk):
        """Write block ``blk``'s pending assignment (if any) into its cells."""
        pending = self.lazy[blk]
        if pending is not None:
            lo = blk * self.b
            hi = min(lo + self.b, len(self.vals))
            self.vals[lo:hi] = [pending] * (hi - lo)
            self.lazy[blk] = None

    def _resum(self, blk):
        """Recompute the stored sum of block ``blk`` from its cells."""
        lo = blk * self.b
        self.sums[blk] = sum(self.vals[lo:lo + self.b])

    def assign(self, l, r, v):
        """Set every cell in the inclusive range [l, r] to ``v``."""
        b = self.b
        bl, br = l // b, r // b
        self._push(bl)
        self._push(br)
        if bl == br:
            self.vals[l:r + 1] = [v] * (r - l + 1)
            self._resum(bl)
            return
        left_end = (bl + 1) * b
        self.vals[l:left_end] = [v] * (left_end - l)
        self._resum(bl)
        right_start = br * b
        self.vals[right_start:r + 1] = [v] * (r + 1 - right_start)
        self._resum(br)
        # Blocks strictly between bl and br are always full (size b), so
        # their sum is v * b. Their cells are updated lazily.
        for blk in range(bl + 1, br):
            self.sums[blk] = v * b
            self.lazy[blk] = v

    def take(self, l, r):
        """Return the sum of cells in [l, r], then reset that range to 0."""
        b = self.b
        bl, br = l // b, r // b
        self._push(bl)
        self._push(br)
        if bl == br:
            total = sum(self.vals[l:r + 1])
        else:
            total = (sum(self.vals[l:(bl + 1) * b])
                     + sum(self.sums[bl + 1:br])
                     + sum(self.vals[br * b:r + 1]))
        self.assign(l, r, 0)
        return total


def solve(n, edges, values, queries, root=0):
    """Answer the "collect the radius-2 ball" queries on a tree.

    Each query ``x`` does three things, in this order:
      1. Sum the current values of every node within distance 2 of ``x``:
         x itself, its parent, children, siblings, grandparent and
         grandchildren.
      2. Reset all of those nodes to 0.
      3. Store the sum at ``x``.

    In BFS order, the children of a node occupy one contiguous index
    range, and so do its grandchildren. Each group above is therefore a
    single range operation on a ``RangeAssignSum`` laid out in BFS order.

    Args:
        n: number of nodes, labelled 0..n-1.
        edges: the n-1 undirected tree edges as (u, v) pairs.
        values: initial value of each node (indexed by node label).
        queries: node labels to process, in order.
        root: BFS root. The answers do not depend on it because the
            radius-2 ball is root-independent. It defaults to 0 so that
            n == 1 works.

    Returns:
        The list of per-query sums.
    """
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    parent = [-1] * n
    pos = [0] * n                       # BFS index of each node
    children = [[] for _ in range(n)]   # kept in BFS discovery order
    seen = [False] * n
    order = [root]
    seen[root] = True
    head = 0
    while head < len(order):
        s = order[head]
        pos[s] = head
        head += 1
        for t in adj[s]:
            if not seen[t]:
                seen[t] = True
                parent[t] = s
                children[s].append(t)
                order.append(t)

    # Children of s are enqueued back to back, so their BFS indices form
    # one contiguous span.
    child_rng = [(pos[k[0]], pos[k[-1]]) if k else None for k in children]
    # Grandchildren are likewise contiguous. The span runs from the first
    # child-with-children's first child to the last one's last child.
    grand_rng = [None] * n
    for s in range(n):
        spans = [child_rng[t] for t in children[s] if child_rng[t] is not None]
        if spans:
            grand_rng[s] = (spans[0][0], spans[-1][1])

    arr = RangeAssignSum([values[node] for node in order])
    answers = []
    for x in queries:
        total = 0
        if child_rng[x] is not None:
            total += arr.take(*child_rng[x])
        if grand_rng[x] is not None:
            total += arr.take(*grand_rng[x])
        px = parent[x]
        if px == -1:
            total += arr.take(pos[x], pos[x])
        else:
            # The parent's child span is x's siblings, x included.
            total += arr.take(*child_rng[px])
            total += arr.take(pos[px], pos[px])
            gp = parent[px]
            if gp != -1:
                total += arr.take(pos[gp], pos[gp])
        answers.append(total)
        arr.assign(pos[x], pos[x], total)
    return answers


def main():
    """Read the input as whitespace-separated tokens and print one answer per line.

    Token order: n, the n-1 edges (u v), n values, Q, then Q query nodes.
    """
    it = iter(sys.stdin.buffer.read().split())
    n = int(next(it))
    edges = [(int(next(it)), int(next(it))) for _ in range(n - 1)]
    values = [int(next(it)) for _ in range(n)]
    q = int(next(it))
    queries = [int(next(it)) for _ in range(q)]
    answers = solve(n, edges, values, queries)
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    main()