from heapq import heapify, heappush, heappop
from collections import deque
from typing import Iterable, Sequence, Tuple


def solve(
    n: int,
    prices: Sequence[int],
    edges: Iterable[Tuple[int, int]],
    sources: Iterable[int],
) -> int:
    """Return the total price collected from leaves while a BFS spreads.

    A multi-source breadth-first search starts from ``sources`` over the
    undirected graph given by ``edges``.  Each time the search reaches a new
    distance level (including level 0), at most one node is "taken" first:
    the highest-priced node that currently has exactly one remaining
    neighbour link, has not been reached by the search, and has not been
    taken before.  Ties are broken by the smaller node index.  A taken node
    is never entered by the search, and taking it lowers the remaining
    degree of its neighbours, which may turn them into new candidates.

    Once the search queue is empty, every candidate still in the heap that
    has not been reached is taken as well (this is what happens for all
    candidates when ``sources`` is empty).

    Args:
        n: Number of nodes; nodes are indexed ``0 .. n-1``.
        prices: ``prices[i]`` is the value gained by taking node ``i``.
        edges: Pairs ``(u, v)`` of 0-based node indices.  The caller in this
            file passes ``n - 1`` of them, i.e. presumably a tree.
        sources: 0-based start nodes of the search.

    Returns:
        The sum of the prices of all taken nodes.
    """
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    dist = [-1] * n          # BFS distance; -1 means "not reached yet"
    taken = [False] * n
    queue = deque()
    for s in sources:
        dist[s] = 0
        queue.append(s)

    # Remaining degree of each node once taken neighbours are discounted.
    degree = [len(neighbours) for neighbours in adj]

    # Max-heap by price (negated for heapq's min-heap), ties by node index.
    candidates = [(-prices[i], i) for i in range(n) if degree[i] == 1]
    heapify(candidates)

    def take_best_candidate() -> int:
        """Take the best still-available candidate and return its price.

        Stale heap entries (nodes reached by the search in the meantime) are
        discarded.  Returns 0 when the heap runs out without a usable node.
        """
        while candidates:
            neg_price, idx = heappop(candidates)
            if dist[idx] != -1 or taken[idx]:
                continue
            taken[idx] = True
            for nb in adj[idx]:
                degree[nb] -= 1
                # A neighbour left with a single link becomes a candidate,
                # provided the search has not already reached it.
                if dist[nb] == -1 and not taken[nb] and degree[nb] == 1:
                    heappush(candidates, (-prices[nb], nb))
            return -neg_price
        return 0

    total = 0
    level = -1
    while queue:
        node = queue.popleft()
        if level < dist[node]:
            # First node of a new distance level: one take happens before
            # this level is expanded.
            total += take_best_candidate()
            level = dist[node]
        for nb in adj[node]:
            if dist[nb] == -1 and not taken[nb]:
                dist[nb] = dist[node] + 1
                queue.append(nb)

    # The search is finished; collect whatever is still available.
    while candidates:
        total += take_best_candidate()

    return total


def main() -> None:
    """Read the instance from standard input and print the answer.

    Input layout (as consumed here): a line ``N M``; a line with the node
    prices; ``N - 1`` lines each holding a 1-based edge ``u v``; and a final
    line with the 1-based start nodes.
    """
    # The second header value is read but not otherwise used; the start
    # nodes are taken from the whole last line.
    n, _m = map(int, input().split())
    prices = list(map(int, input().split()))
    edges = []
    for _ in range(n - 1):
        u, v = map(int, input().split())
        edges.append((u - 1, v - 1))
    sources = [int(tok) - 1 for tok in input().split()]
    print(solve(n, prices, edges, sources))


if __name__ == "__main__":
    main()