from sys import stdin, stdout


def parse_input(text):
    """Parse the raw problem input into ``(n, edges, marked)``, all 0-based.

    Expected layout (whitespace-separated integers, 1-based vertex ids)::

        n
        a_1 b_1          (n - 1 edge lines)
        ...
        m
        x_1 ... x_m      (marked vertices)

    Tokenizing the whole text avoids the original ``readline()[:-1]`` trick,
    which chopped a real digit off a final line lacking a trailing newline.

    Returns:
        n: number of vertices.
        edges: list of ``(a, b)`` pairs, 0-based.
        marked: list of marked vertices, 0-based.
    """
    tokens = iter(map(int, text.split()))
    n = next(tokens)
    # Tuple elements are evaluated left to right, so each pair reads two tokens in order.
    edges = [(next(tokens) - 1, next(tokens) - 1) for _ in range(n - 1)]
    m = next(tokens)
    marked = [next(tokens) - 1 for _ in range(m)]
    return n, edges, marked


def solve(n, edges, marked):
    """For every vertex v, count the vertices w whose closed neighbourhood
    contains at least one marked vertex lying outside the closed
    neighbourhood of v.

    Equivalently: answer[v] = (#w with a marked vertex in N[w])
                              - (#w with a non-empty marked set inside N[v]).

    The graph is assumed to be a tree (it is read as n - 1 edges). The
    counting shortcut below relies on the absence of cycles: two distinct
    neighbours of v can only both be in N[w] when w == v.

    Args:
        n: number of vertices.
        edges: iterable of 0-based ``(a, b)`` pairs.
        marked: iterable of 0-based marked vertices.

    Returns:
        list of n ints, the answer for each vertex in order.
    """
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    is_marked = [False] * n
    for x in marked:
        is_marked[x] = True

    # near[w]: marked vertices in the closed neighbourhood N[w].
    # pattern_count[key]: number of w whose marked set is exactly the sorted
    # tuple ``key``; only sets of size 1 or 2 are ever looked up.
    near = [None] * n
    pattern_count = {}
    covered = 0
    for w in range(n):
        hits = [u for u in adj[w] if is_marked[u]]
        if is_marked[w]:
            hits.append(w)
        if 1 <= len(hits) <= 2:
            key = tuple(sorted(hits))
            pattern_count[key] = pattern_count.get(key, 0) + 1
        if hits:
            covered += 1
        near[w] = hits

    answers = []
    for v in range(n):
        # Remove every w whose (non-empty) marked set fits inside N[v]:
        # the sets {v}, {u}, and {u, v} for each neighbour u of v.
        count = covered - pattern_count.get((v,), 0)
        for u in adj[v]:
            count -= pattern_count.get((u,), 0)
            count -= pattern_count.get((min(u, v), max(u, v)), 0)
        # Any other marked set inside N[v] contains two neighbours of v.
        # In a tree that is only possible for w == v itself, so remove it once.
        own = near[v]
        if len(own) >= 3 or (len(own) == 2 and v not in own):
            count -= 1
        answers.append(count)
    return answers


def main():
    """Read the problem from stdin and write one answer per line to stdout."""
    n, edges, marked = parse_input(stdin.read())
    answers = solve(n, edges, marked)
    if answers:
        # Single batched write instead of one print() per vertex.
        stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    main()