"""Greedy leaf collection on a tree while a spread expands from seed nodes.

Input (stdin, whitespace-separated tokens):
    n m
    p_1 .. p_n            value of each node
    n-1 pairs "u v"       tree edges, 1-based node ids
    c_1 .. c_m            nodes the spread starts from, 1-based

Each round, the highest-valued unclaimed leaf of the tree (degree counted
only against neighbours that have been collected) is collected and its value
is added to the total. The spread then advances one step from its frontier
to every unclaimed neighbour. The program prints the total collected value.
"""

import sys
from heapq import heapify, heappop, heappush


def solve(n, p, edges, c):
    """Return the total value collected by the greedy process.

    Args:
        n: number of nodes.
        p: node values; ``p[i]`` belongs to 0-based node ``i``.
        edges: ``n - 1`` pairs ``(u, v)`` of 1-based node ids forming a tree.
        c: 1-based ids of the nodes the spread starts from.

    Returns:
        Sum of the values of all collected nodes.
    """
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u - 1].append(v - 1)
        adj[v - 1].append(u - 1)

    # Degree minus already-collected neighbours. Only collecting lowers it,
    # so nodes claimed by the spread still count toward a neighbour's degree.
    deg = [len(nbrs) for nbrs in adj]
    claimed = [False] * n  # True once collected or reached by the spread

    frontier = []
    for x in c:
        claimed[x - 1] = True
        frontier.append(x - 1)

    # Max-heap via negated values; equal values pop the smaller index first.
    # NOTE(review): a lone node (n == 1) has degree 0 and is never collected;
    # kept as in the original -- confirm against the problem statement.
    heap = [(-p[i], i) for i in range(n) if deg[i] == 1 and not claimed[i]]
    heapify(heap)

    total = 0
    while True:
        # Lazily drop entries that were claimed after being pushed.
        while heap and claimed[heap[0][1]]:
            heappop(heap)
        if not heap:
            # Entries are pushed only when a node is collected, so an empty
            # heap means no further collection is possible: the total is final.
            break

        neg_value, node = heappop(heap)
        total -= neg_value
        claimed[node] = True
        for nb in adj[node]:
            deg[nb] -= 1
            if deg[nb] == 1:
                heappush(heap, (-p[nb], nb))

        # Advance the spread by one step; it cannot pass through claimed nodes.
        next_frontier = []
        for src in frontier:
            for nb in adj[src]:
                if not claimed[nb]:
                    claimed[nb] = True
                    next_frontier.append(nb)
        frontier = next_frontier

    return total


def main():
    """Parse the instance from stdin and print the answer."""
    it = map(int, sys.stdin.buffer.read().split())
    n, m = next(it), next(it)
    p = [next(it) for _ in range(n)]
    edges = [(next(it), next(it)) for _ in range(n - 1)]
    # Assumes the seed line holds exactly m ids (m is otherwise unused).
    c = [next(it) for _ in range(m)]
    print(solve(n, p, edges, c))


if __name__ == "__main__":
    main()