"""For every vertex of a tree, count vertices that stay covered.

Input (stdin):
    n
    n - 1 lines "u v"        -- 1-based edge endpoints
    m                        -- number of marked vertices
    one line of 1-based marked vertex ids

Output (stdout): n lines, one integer per vertex in order 1..n.
"""
import sys


def count_covered_after_clearing(n, edges, marked):
    """Return, for each vertex p, the covered-vertex count after clearing p.

    A vertex is "covered" when it is marked or has a marked neighbour.
    For each p the result is the number of vertices still covered once the
    marks on p and on all neighbours of p are ignored.

    The per-vertex formula only looks at p's neighbours and at precomputed
    per-neighbour counts, so it relies on the graph having no cycles
    (assumes `edges` form a tree -- the caller reads exactly n - 1 edges).

    Args:
        n: number of vertices, labelled 0..n-1.
        edges: iterable of (u, v) pairs with 0-based endpoints.
        marked: iterable of 0-based marked vertex ids; duplicates and ids
            outside 0..n-1 are ignored.

    Returns:
        A list of n ints; entry p is the answer for vertex p.
    """
    adjacency = [[] for _ in range(n)]
    for u, v in edges:
        adjacency[u].append(v)
        adjacency[v].append(u)

    marked_set = set(marked)
    is_marked = [p in marked_set for p in range(n)]

    # cover[p]: how many marked vertices lie in p's closed neighbourhood.
    cover = [
        int(is_marked[p]) + sum(is_marked[v] for v in adjacency[p])
        for p in range(n)
    ]

    # sole[p] (marked p only): neighbours whose single covering mark is p.
    sole = [
        sum(cover[v] == 1 for v in adjacency[p]) if is_marked[p] else 0
        for p in range(n)
    ]

    covered_total = sum(1 for count in cover if count)

    answers = []
    for p in range(n):
        # p itself loses every mark that covered it.
        remaining = covered_total - (cover[p] > 0)
        for v in adjacency[p]:
            # v loses the marks on p and on v; it drops out if those were
            # the only marks covering it.
            if cover[v] and cover[v] - is_marked[p] - is_marked[v] == 0:
                remaining -= 1
            # Vertices two steps away that were covered only by v drop out
            # too.  p is excluded here because it was already subtracted.
            if is_marked[v]:
                remaining -= sole[v] - (cover[p] == 1)
        answers.append(remaining)
    return answers


def main():
    """Read the tree and marked vertices from stdin and print the answers."""
    readline = sys.stdin.readline

    n = int(readline())
    edges = []
    for _ in range(n - 1):
        u, v = map(int, readline().split())
        edges.append((u - 1, v - 1))

    # The count line is consumed but not needed: every id on the following
    # line is used.  A missing or empty id line yields an empty set.
    readline()
    marked = {int(token) - 1 for token in readline().split()}

    answers = count_covered_after_clearing(n, edges, marked)
    # One write for all lines; emits nothing when there are no vertices.
    sys.stdout.write("".join(f"{value}\n" for value in answers))


if __name__ == "__main__":
    main()