"""For each vertex of a tree, count the vertices that are still "haunted"
after every ghost in that vertex's closed neighbourhood is removed.

Input (stdin, whitespace-separated):
    n
    n-1 edges "a b" (1-based vertex labels)
    m
    m ghost vertices (1-based)
Output: n lines, the answer for vertex 1..n in order.

A vertex v is haunted when at least one ghost lies in its closed
neighbourhood N[v] (v itself plus its neighbours).
NOTE(review): these semantics are reconstructed from the arithmetic; the
original problem statement is not available -- confirm.
"""
import sys


def solve(n, edges, ghosts):
    """Return, for every vertex i, the haunted-vertex count once all ghosts
    in N[i] are removed.

    Args:
        n: number of vertices, labelled 0..n-1.
        edges: iterable of 0-based (a, b) pairs, assumed to form a tree.
            The local counting below relies on the graph having no
            triangles. In a tree, N[i] and N[j] for an edge (i, j) share
            only i and j, and N[i] meets N[w] (w two steps away via j) only
            in j.
        ghosts: iterable of 0-based ghost vertices (duplicates ignored).

    Returns:
        A list of n ints.
    """
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    ghost_set = set(ghosts)

    # effect[v]: number of ghosts in N[v]; v is haunted iff effect[v] > 0.
    effect = [0] * n
    for v in ghost_set:
        effect[v] += 1
        for u in adj[v]:
            effect[u] += 1

    # only_one[v]: neighbours of v haunted by exactly one ghost. When v is a
    # ghost these are exactly the neighbours that only v haunts.
    only_one = [sum(1 for u in adj[v] if effect[u] == 1) for v in range(n)]
    haunted = sum(1 for e in effect if e)

    result = []
    for i in range(n):
        ans = haunted
        i_is_ghost = i in ghost_set
        for j in adj[i]:
            if j in ghost_set:
                # Removing ghost j frees every neighbour of j that only j
                # haunts...
                ans -= only_one[j]
                # ...but i itself is accounted for once, below; undo it here
                # if only_one[j] counted it.
                if effect[i] == 1:
                    ans += 1
                # j is freed when its only ghosts are j (and i, if i is one).
                if effect[j] == (2 if i_is_ghost else 1):
                    ans -= 1
            elif i_is_ghost and effect[j] == 1:
                # A non-ghost neighbour haunted solely by i is freed.
                ans -= 1
        # Every ghost in N[i] is removed, so i is freed -- but only subtract
        # when i was haunted to begin with; otherwise an unhaunted vertex
        # would wrongly be taken off the count.
        if effect[i] > 0:
            ans -= 1
        result.append(ans)
    return result


def main():
    """Read the tree and ghosts from stdin and print one answer per vertex."""
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    pos = 1
    edges = []
    for _ in range(n - 1):
        edges.append((int(data[pos]) - 1, int(data[pos + 1]) - 1))
        pos += 2
    m = int(data[pos])
    pos += 1
    ghosts = [int(x) - 1 for x in data[pos:pos + m]]
    answers = solve(n, edges, ghosts)
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    main()