import sys
# Raise the interpreter's recursion limit to 10**8 so that deeply nested
# calls in this script are not stopped early by RecursionError.
sys.setrecursionlimit(10 ** 8)


# First input line holds N and M: N sizes the disjoint-set forest below and
# M is the number of "u v" lines read later in the script.
N, M = (int(tok) for tok in input().split())
# Second input line: one integer per entry of C (C is indexed by 0-based
# vertex number further down; presumably a colour/label per vertex).
C = [int(tok) for tok in input().split()]

# Disjoint-set forest: every index starts as the root of its own set.
parent = list(range(N))
def find(i):
    """Return the root of the set containing ``i``, compressing the path.

    Reads and updates the module-level ``parent`` list: after the call every
    node that was on the path from ``i`` to the root points directly at the
    root.

    Implemented iteratively so that a long parent chain (``unite`` does no
    balancing by rank or size) cannot exhaust the call stack.
    """
    # Pass 1: walk up the chain until a self-parented node (the root).
    root = i
    while parent[root] != root:
        root = parent[root]
    # Pass 2: re-point every node on the original path straight at the root.
    while i != root:
        nxt = parent[i]
        parent[i] = root
        i = nxt
    return root
def unite(i, j):
    """Merge the sets containing ``i`` and ``j``.

    Returns ``True`` when two different sets were joined and ``False`` when
    ``i`` and ``j`` already shared a root (nothing is modified in that case
    beyond the path compression done by ``find``).
    """
    root_i, root_j = find(i), find(j)
    if root_i != root_j:
        # Hang i's old root under j's root; i itself is also pointed
        # directly at the new root.
        parent[i] = parent[root_i] = root_j
        return True
    return False

# Distinct values that occur in C.
col = set(C)
# Number of unions that actually merged two different sets.
count = 0
for _ in range(M):
    # Endpoints are given 1-based; shift both to 0-based indices.
    u, v = (int(tok) - 1 for tok in input().split())
    # Only an edge whose endpoints have equal C values may join two sets.
    if C[u] == C[v] and unite(u, v):
        count += 1
# N - count sets remain in the forest; subtract one for each distinct value
# in C and print the difference.
print(N - count - len(col))