import sys
from collections import deque
from heapq import heapify, heappop, heappush
from operator import itemgetter  # noqa: F401  -- unused; kept from the original file

_INF = 1 << 60  # "unreached" sentinel distance for the BFS


def _bfs_distances(adj, sources):
    """Return multi-source BFS distances (edge count) from the nearest source.

    Vertices not reachable from any source keep the sentinel ``_INF``.
    """
    dist = [_INF] * len(adj)
    queue = deque()
    for s in sources:
        dist[s] = 0
        queue.append(s)
    while queue:
        x = queue.popleft()
        nd = dist[x] + 1
        for y in adj[x]:
            if dist[y] > nd:
                dist[y] = nd
                queue.append(y)
    return dist


def solve(n, prices, edges, chasers):
    """Greedily peel leaves of a tree and return the sum of collected prices.

    Args:
        n: number of vertices.
        prices: ``prices[v]`` is the value gained by collecting vertex ``v`` (0-based).
        edges: ``n - 1`` pairs ``(a, b)`` of 1-based vertex ids.
        chasers: 1-based vertex ids used as BFS sources.

    A leaf ``v`` popped from the max-price heap is collected only if its BFS
    distance to the nearest chaser is strictly greater than the number of
    vertices collected so far. Collecting a leaf detaches it from the tree;
    a neighbour whose degree drops to exactly 1 becomes a new candidate.
    Ties on price are broken by smaller distance, then smaller vertex index.
    """
    adj = [set() for _ in range(n)]
    for a, b in edges:
        a -= 1
        b -= 1
        adj[a].add(b)
        adj[b].add(a)

    dist = _bfs_distances(adj, [c - 1 for c in chasers])

    # Negate prices so the min-heap pops the most valuable leaf first.
    heap = [(-prices[v], dist[v], v) for v in range(n) if len(adj[v]) == 1]
    heapify(heap)

    now = 0  # number of leaves collected so far
    total = 0
    while heap:
        neg_price, d, v = heappop(heap)
        # `now` never decreases, so a leaf rejected here can never qualify later.
        if d <= now:
            continue
        total -= neg_price
        now += 1
        # A queued vertex had degree 1 when pushed and degree only shrinks, so
        # it now has 0 or 1 neighbours. With 0 (its last neighbour was already
        # peeled -- only possible when no chaser is in its component) there is
        # nothing to detach; the original code crashed here with IndexError.
        if not adj[v]:
            continue
        u = adj[v].pop()
        adj[u].remove(v)
        if len(adj[u]) == 1:
            heappush(heap, (-prices[u], dist[u], u))
    return total


def main():
    """Read the instance from stdin (line-based, as the original did) and print the answer."""
    readline = sys.stdin.readline
    n, _m = map(int, readline().split())  # M is read but the chaser line is taken whole
    prices = list(map(int, readline().split()))
    edges = [tuple(map(int, readline().split())) for _ in range(n - 1)]
    chasers = list(map(int, readline().split()))
    print(solve(n, prices, edges, chasers))


if __name__ == "__main__":
    main()