import sys input = sys.stdin.readline from collections import deque from operator import itemgetter N,M=map(int,input().split()) P=list(map(int,input().split())) E=[[] for i in range(N)] for i in range(N-1): a,b=map(int,input().split()) a-=1 b-=1 E[a].append(b) E[b].append(a) DIS=[1<<50]*N Q=deque() C=list(map(int,input().split())) for c in C: c-=1 Q.append(c) DIS[c]=0 while Q: x=Q.popleft() for to in E[x]: if DIS[to]>DIS[x]+1: DIS[to]=DIS[x]+1 Q.append(to) C=[] for i in range(N): if len(E[i])==1: C.append((P[i],DIS[i])) C.sort(key=itemgetter(1)) C.sort(key=itemgetter(0),reverse=True) now=0 ANS=0 for p,d in C: if d<=now: continue else: ANS+=p now+=1 print(ANS)