"""Count the extra directed edges needed to chain a graph's components into one trail.

Input (stdin): ``N M`` followed by ``M`` directed edges ``u v`` over
vertices ``1..N``.
Output: one integer, ``sum over weakly connected components c of
max(1, P_c) - 1``, where ``P_c`` is the sum of the positive
(in-degree - out-degree) values inside component ``c``. Isolated vertices
count as their own components.

NOTE(review): this presumably answers "minimum number of directed edges to
add so that a single trail uses every edge exactly once and visits every
vertex" -- confirm against the original problem statement.
"""
import sys
from collections import defaultdict


def _find(parent, i):
    """Return the representative of ``i``'s set, with iterative path compression."""
    root = i
    while parent[root] != root:
        root = parent[root]
    # Second pass: re-point every node on the walked path straight at the root.
    # Done iteratively so long parent chains cannot overflow the call stack.
    while i != root:
        nxt = parent[i]
        parent[i] = root
        i = nxt
    return root


def _unite(parent, i, j):
    """Merge the sets holding ``i`` and ``j``; return True if they were distinct."""
    ri = _find(parent, i)
    rj = _find(parent, j)
    if ri == rj:
        return False
    parent[ri] = rj
    return True


def solve(n, edges):
    """Compute the answer for ``n`` vertices (1-based) and directed ``edges``.

    Args:
        n: number of vertices; vertices are ``1..n``.
        edges: iterable of ``(u, v)`` pairs, each a directed edge ``u -> v``.

    Returns:
        ``sum(max(1, P_c) for each weakly connected component c) - 1``, where
        ``P_c`` is the sum of the positive (in - out) degree balances of ``c``.
        This equals ``(components - 1) + sum(max(0, P_c - 1))``.
        Returns -1 when ``n == 0``, because there are no components.
    """
    balance = [0] * (n + 1)  # in-degree minus out-degree, per vertex
    parent = list(range(n + 1))
    for u, v in edges:
        balance[u] -= 1
        balance[v] += 1
        _unite(parent, u, v)  # direction is ignored for connectivity

    # Component root -> sum of positive balances. Every vertex's root gets a key,
    # even when that sum is 0, so components without surplus are still counted.
    surplus = defaultdict(int)
    for i in range(1, n + 1):
        surplus[_find(parent, i)] += max(balance[i], 0)

    # A component contributes max(1, P) pieces; linking k pieces needs k - 1 edges.
    return sum(max(1, p) for p in surplus.values()) - 1


def main():
    """Read the graph from stdin and print the answer."""
    data = sys.stdin.buffer.read().split()
    n, m = int(data[0]), int(data[1])
    nums = list(map(int, data[2:2 + 2 * m]))
    print(solve(n, zip(nums[0::2], nums[1::2])))


if __name__ == "__main__":
    main()