"""Count surplus same-colour connected components in an undirected graph.

Input (stdin):
    line 1: two integers ``n m`` (``n`` = number of vertices)
    line 2: the colour of each of the ``n`` vertices
    then:   ``n`` lines ``u v`` -- an undirected edge between the 1-based
            vertices ``u`` and ``v``

Output (stdout):
    For every colour that occurs, the number of maximal connected groups of
    vertices sharing that colour, minus one; these values summed over all
    colours.
"""


def solve(n, colors, edges):
    """Return the summed surplus of monochromatic components per colour.

    A monochromatic component is a maximal set of vertices that all have the
    same colour and are connected through edges whose endpoints both have
    that colour.

    Args:
        n: Number of vertices; vertices are numbered ``1..n`` in ``edges``.
        colors: Sequence where ``colors[i]`` is the colour of vertex ``i + 1``.
        edges: Iterable of ``(u, v)`` pairs of 1-based vertex numbers.

    Returns:
        The sum over all occurring colours of
        ``(number of components of that colour) - 1``.
    """
    adjacency = [[] for _ in range(n)]
    for u, v in edges:
        adjacency[u - 1].append(v - 1)
        adjacency[v - 1].append(u - 1)

    visited = [False] * n
    # Keyed by colour value rather than by ``colour - 1`` in a fixed-size
    # list, so the tally does not depend on colours lying in 1..n.
    components_per_color = {}

    for start in range(n):
        if visited[start]:
            continue

        # An unvisited vertex opens a new component of its colour.
        color = colors[start]
        components_per_color[color] = components_per_color.get(color, 0) + 1
        visited[start] = True

        # The list is extended while being iterated, so it serves as a
        # FIFO work queue for a breadth-first flood fill.
        queue = [start]
        for node in queue:
            for neighbour in adjacency[node]:
                if not visited[neighbour] and colors[neighbour] == color:
                    visited[neighbour] = True
                    queue.append(neighbour)

    return sum(count - 1 for count in components_per_color.values())


def main():
    """Read the graph from stdin and print the result of :func:`solve`."""
    n, _m = map(int, input().split())
    colors = list(map(int, input().split()))
    # NOTE(review): the second header value is read but never used, and
    # exactly ``n`` edge lines are consumed. Confirm against the problem
    # statement that the edge count really is ``n`` (and not ``m``).
    edges = [tuple(map(int, input().split())) for _ in range(n)]
    print(solve(n, colors, edges))


if __name__ == "__main__":
    main()