# Tree-game script (visible logic summarized below; code line kept byte-identical).
#
# NOTE(review): this file is damaged and will not run as-is:
#   * Every newline and indent has been collapsed onto this single line, so the
#     block structure (which statements belong to which loop/if) is lost.
#   * Text is missing in the second pass: `if u1[t]0:` is not valid Python.
#     It looks like a span beginning with `<` and ending with `>` was stripped
#     (e.g. by HTML sanitization). Restore the original source before editing.
#
# Input (as read by the visible code):
#   * line 1: n, m  -- m is not referenced anywhere in the visible code.
#   * line 2: list `a` of ints; later used via a.index(s+1)+1 to print the
#     1-based position of value s+1 in `a`. Presumably `a` is a permutation of
#     1..n (TODO confirm); list.index raises ValueError if the value is absent.
#   * n-1 lines of 1-based edges u v, converted to 0-based and stored in both
#     directions in adjacency list `e` (presumably a tree -- TODO confirm).
#
# Pass 1: iterative DFS from vertex 0 with explicit stack `q`.
#   * v[x]==1 marks vertices currently on the stack path; on first visit the
#     unvisited neighbors are recorded as children (p[t]=s) and pushed.
#   * On the second visit (post-order), c[s][k] counts children t with u1[t]==k
#     (only k <= len(e[s]) is stored; c[s] has len(e[s])+1 slots), and
#     u1[s] becomes the smallest k with c[s][k]==0, i.e. the mex of the
#     children's u1 values.
#   * v[s] is reset to 0 on pop, so `v` is all zeros again for pass 2.
#
# Pass 2 (partially missing): second DFS from vertex 0.
#   * For non-root s, u2[s] (if <= len(e[s])) is added into c[s]; g is then
#     the mex over c[s]. This looks like a rerooting step where u2[s] stands
#     for the parent-side value of s -- assumption, verify once restored.
#   * The code that updates c/u2 for children is lost in the corrupted span.
#
# Final check (visible tail; loop/break nesting ambiguous after collapse):
#   * g is XORed with indices i where c[s][i]==0.
#   * For each neighbor t of s: if t is not the parent, test g^u1[t]==0;
#     otherwise test g^u2[s]==0. On the first success it prints
#     a.index(s+1)+1 and t+1, then calls exit().
n,m=map(int,input().split()) a=list(map(int,input().split())) e=[[] for i in range(n)] for i in range(n-1): u,v=map(int,input().split()) u-=1 v-=1 e[u]+=[v] e[v]+=[u] v=[0]*n p=[0]*n u1=[0]*n c=[[0]*(len(e[i])+1) for i in range(n)] q=[0] while len(q)>0: s=q[-1] if v[s]==0: v[s]=1 for t in e[s]: if v[t]==0: p[t]=s q+=[t] else: for t in e[s]: if v[t]==0 and u1[t]<=len(e[s]): c[s][u1[t]]+=1 for i in range(len(e[s])+1): if c[s][i]==0: u1[s]=i break v[s]=0 q.pop() u2=[0]*n q=[0] while len(q)>0: s=q[-1] if v[s]==0: v[s]=1 if s!=0: if u2[s]<=len(e[s]): c[s][u2[s]]+=1 for i in range(len(e[s])+1): if c[s][i]==0: g=i break for t in e[s]: if v[t]==0: if u1[t]0: for i in range(len(e[s])+1): if c[s][i]==0: g^=i for t in e[s]: if p[s]!=t: if g^u1[t]==0: print(a.index(s+1)+1,t+1) exit() else: if g^u2[s]==0: print(a.index(s+1)+1,t+1) exit()