n,m=map(int,input().split()) a=list(map(int,input().split())) e=[[] for i in range(n)] for i in range(n-1): u,v=map(int,input().split()) u-=1 v-=1 e[u]+=[v] e[v]+=[u] v=[0]*n p=[0]*n u1=[0]*n c=[[0]*(len(e[i])+1) for i in range(n)] q=[0] while len(q)>0: s=q[-1] if v[s]==0: v[s]=1 for t in e[s]: if v[t]==0: p[t]=s q+=[t] else: for t in e[s]: if v[t]==0 and u1[t]<=len(e[s]): c[s][u1[t]]+=1 for i in range(len(e[s])+1): if c[s][i]==0: u1[s]=i break v[s]=0 q.pop() u2=[0]*n q=[0] while len(q)>0: s=q[-1] if v[s]==0: v[s]=1 if s!=0: if u2[s]<=len(e[s]): c[s][u2[s]]+=1 for i in range(len(e[s])+1): if c[s][i]==0: g=i break for t in e[s]: if v[t]==0: if u1[t]<=len(e[s]) and c[s][u1[t]]-1==0: u2[t]=u1[t] else: u2[t]=g q+=[t] else: v[s]=0 q.pop() f=[0]*n for y in a: f[y-1]^=1 g=0 for s in range(n): if f[s]: for i in range(len(e[s])+1): if c[s][i]==0: g^=i break if g==0: print(-1,-1) exit() for s in range(n): if f[s]: for i in range(len(e[s])+1): if c[s][i]==0: g^=i for t in e[s]: if p[s]!=t: if g^u1[t]==0: print(a.index(s+1)+1,t+1) exit() elif s!=0: if g^u2[s]==0: print(a.index(s+1)+1,t+1) exit()