"""Count how many joins are still needed to connect every colour class.

Input format (stdin), as consumed by ``main``:

    N M
    C_1 C_2 ... C_N
    u_1 v_1
    ...
    u_M v_M

Vertices in the edge list are 1-based.  Only edges whose two endpoints
share a colour are taken into account.  The printed value is

    N - (number of distinct colours) - (number of successful unions)

i.e. the number of connected components formed by same-colour edges minus
the number of distinct colours.
NOTE(review): presumably this is the minimum number of edges to add so
that each colour forms a single connected component -- confirm against
the problem statement.
"""
import sys
from typing import Iterable, Sequence, Tuple


def count_missing_edges(
    n: int,
    colors: Sequence[int],
    edges: Iterable[Tuple[int, int]],
) -> int:
    """Return ``n - len(set(colors)) - merges`` for the given graph.

    Args:
        n: Number of vertices; vertices are ``0 .. n - 1``.
        colors: ``colors[v]`` is the colour of vertex ``v``.
        edges: 0-based ``(u, v)`` pairs.  Edges whose endpoints have
            different colours are ignored.

    Returns:
        ``n`` minus the number of distinct colours minus the number of
        same-colour edges that joined two previously separate components.
    """
    parent = list(range(n))

    def find(v: int) -> int:
        # Iterative on purpose: roots are linked without rank/size
        # balancing, so parent chains can become as long as the input,
        # which would make a recursive lookup overflow the stack.
        root = v
        while parent[root] != root:
            root = parent[root]
        # Path compression: re-point every node on the walked path
        # directly at the root.
        while parent[v] != root:
            next_v = parent[v]
            parent[v] = root
            v = next_v
        return root

    merges = 0
    for u, v in edges:
        if colors[u] != colors[v]:
            continue
        root_u = find(u)
        root_v = find(v)
        if root_u != root_v:
            parent[root_u] = root_v
            merges += 1

    # Each successful union removes one component, so ``n - merges`` is the
    # component count of the same-colour graph.
    return n - len(set(colors)) - merges


def main() -> None:
    """Read the graph from stdin and print the result."""
    readline = sys.stdin.readline
    n, m = map(int, readline().split())
    colors = list(map(int, readline().split()))

    def read_edges():
        # Stream the edges, converting from 1-based to 0-based indices.
        for _ in range(m):
            u, v = map(int, readline().split())
            yield u - 1, v - 1

    print(count_missing_edges(n, colors, read_edges()))


if __name__ == "__main__":
    main()