"""Answer a multi-token game on a tree using Sprague-Grundy values.

Reads from stdin:
    n m
    a_1 ... a_m          1-based vertex of each of the m tokens
    n-1 lines "u v"      1-based tree edges

Prints "-1 -1" when the XOR of all token values is zero.  Otherwise it prints
"<token> <vertex>" (both 1-based): moving that token to that neighbouring
vertex reaches a zero-XOR position.

The value of a token at vertex s is computed as the mex, over every
neighbour t, of the value of a token at t inside the component of
(tree - s) that contains t.  NOTE(review): this recurrence is inferred from
the original code; the problem statement is not part of this file, so
confirm the move rules against it.
"""

import sys


def _mex_counts(values, size):
    """Return (counts, mex) for *values*, counting only values < size.

    The mex of k values is at most k, so size = k + 1 is always enough.
    Larger values cannot affect the mex and are ignored.
    """
    counts = [0] * size
    for x in values:
        if x < size:
            counts[x] += 1
    mex = next((i for i, c in enumerate(counts) if c == 0), size)
    return counts, mex


def _root_tree(adj):
    """Root the tree at vertex 0 and return (parent, order).

    parent[0] == -1, and *order* lists every vertex after its parent, so
    reversed(order) visits children before parents.  The traversal is
    iterative to avoid Python's recursion limit on deep trees.
    """
    n = len(adj)
    parent = [-1] * n
    order = []
    stack = [0] if n else []
    while stack:
        s = stack.pop()
        order.append(s)
        for t in adj[s]:
            if t != parent[s]:
                parent[t] = s
                stack.append(t)
    return parent, order


def _grundy_values(adj):
    """Return (parent, down, up, full) Grundy tables for single tokens.

    down[s]: value of a token at s restricted to the subtree of s.
    up[t]:   for t != root, value of a token at parent[t] after the edge
             (parent[t], t) is removed, i.e. in the component away from t.
    full[s]: value of a token at s in the whole tree.
    """
    n = len(adj)
    parent, order = _root_tree(adj)

    down = [0] * n
    for s in reversed(order):  # children are finished before their parent
        child_values = [down[t] for t in adj[s] if t != parent[s]]
        down[s] = _mex_counts(child_values, len(child_values) + 1)[1]

    up = [0] * n
    full = [0] * n
    for s in order:  # parents first, so up[s] is already known here
        options = [down[t] for t in adj[s] if t != parent[s]]
        if parent[s] != -1:
            options.append(up[s])
        counts, g = _mex_counts(options, len(options) + 1)
        full[s] = g
        for t in adj[s]:
            if t == parent[s]:
                continue
            d = down[t]
            # Removing t's option lowers the mex only if t supplied the sole
            # copy of a value below the current mex.  The `d < g` test also
            # keeps counts[d] in range.
            up[t] = d if d < g and counts[d] == 1 else g
    return parent, down, up, full


def solve(n, tokens, edges):
    """Return a winning first move, or (-1, -1) if the position is lost.

    n:      number of vertices (labelled 1..n)
    tokens: 1-based vertex of each token
    edges:  n-1 pairs (u, v) of 1-based vertices forming a tree
    Returns (token_index, destination_vertex), both 1-based.
    Raises RuntimeError if no move is found for a non-zero XOR, which
    the Sprague-Grundy theorem rules out for valid input.
    """
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u - 1].append(v - 1)
        adj[v - 1].append(u - 1)
    parent, down, up, full = _grundy_values(adj)

    # Tokens on the same vertex cancel in pairs, so only parity matters.
    odd = [False] * n
    for y in tokens:
        odd[y - 1] = not odd[y - 1]
    total = 0
    for s in range(n):
        if odd[s]:
            total ^= full[s]
    if total == 0:
        return -1, -1

    for s in range(n):
        if not odd[s]:
            continue
        # This token must move to an option worth total ^ full[s] for the
        # overall XOR to become zero.  `total` itself is never modified, so
        # every candidate vertex is tested against the correct target.
        target = total ^ full[s]
        for t in adj[s]:
            value = up[s] if t == parent[s] else down[t]
            if value == target:
                return tokens.index(s + 1) + 1, t + 1
    raise RuntimeError("no winning move found")


def main():
    """Read the instance from stdin and print the answer."""
    data = sys.stdin.read().split()
    n, m = int(data[0]), int(data[1])
    tokens = [int(x) for x in data[2:2 + m]]
    rest = data[2 + m:]
    edges = [(int(rest[2 * i]), int(rest[2 * i + 1])) for i in range(n - 1)]
    print(*solve(n, tokens, edges))


if __name__ == "__main__":
    main()