"""Minimum number of directed edges to add so a digraph becomes strongly connected.

Input (stdin): ``n m`` followed by ``m`` pairs ``a b`` (1-indexed directed
edge ``a -> b``).  Output: a single integer.

The answer is computed on the condensation (DAG of strongly connected
components): 0 if there is at most one SCC, otherwise
``max(#SCCs with no incoming edge, #SCCs with no outgoing edge)``
(Eswaran & Tarjan, 1976).
"""
import sys


def strongly_connected_components(n, graph, rgraph):
    """Label every vertex with its SCC index using Kosaraju's algorithm.

    Both DFS passes are iterative, so deep graphs cannot hit Python's
    recursion limit.

    Args:
        n: number of vertices, labelled ``0 .. n-1``.
        graph: adjacency lists of the graph (``graph[u]`` = successors of u).
        rgraph: adjacency lists of the reversed graph.

    Returns:
        ``(count, comp)`` where ``count`` is the number of SCCs and
        ``comp[v]`` is the SCC index of ``v``.  Indices are assigned in a
        topological order of the condensation (sources first).
    """
    # Pass 1: record vertices in order of DFS completion.
    order = []
    used = [False] * n
    stack = []
    for root in range(n):
        if used[root]:
            continue
        stack.append(root)
        while stack:
            cur = stack.pop()
            if cur < 0:
                # ~cur (== -cur - 1) is a "finished" marker pushed below
                # the vertex's children, so it pops after its subtree.
                order.append(~cur)
                continue
            if used[cur]:
                # Pushed more than once before being visited; skip duplicates.
                continue
            used[cur] = True
            stack.append(~cur)
            stack.extend(graph[cur])

    # Pass 2: DFS on the reversed graph in decreasing finish time; each
    # tree found is exactly one SCC.
    comp = [-1] * n
    count = 0
    for start in reversed(order):
        if comp[start] != -1:
            continue
        stack.append(start)
        while stack:
            cur = stack.pop()
            if comp[cur] != -1:
                continue
            comp[cur] = count
            stack.extend(rgraph[cur])
        count += 1
    return count, comp


def min_edges_to_strongly_connect(n, edges):
    """Return how many directed edges must be added to make the graph strongly connected.

    Args:
        n: number of vertices, labelled ``0 .. n-1``.
        edges: iterable of 0-indexed directed edges ``(a, b)`` meaning
            ``a -> b``.  Self-loops and parallel edges are allowed.

    Returns:
        0 when the graph has at most one SCC; otherwise
        ``max(sources, sinks)`` of its condensation.
    """
    graph = [[] for _ in range(n)]
    rgraph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        rgraph[b].append(a)

    count, comp = strongly_connected_components(n, graph, rgraph)
    if count <= 1:
        return 0

    has_in = [False] * count
    has_out = [False] * count
    for a in range(n):
        ca = comp[a]
        for b in graph[a]:
            cb = comp[b]
            if ca != cb:  # only edges between different SCCs matter
                has_out[ca] = True
                has_in[cb] = True

    sources = has_in.count(False)
    sinks = has_out.count(False)
    return max(sources, sinks)


def main():
    """Read the graph from stdin (1-indexed edges) and print the answer."""
    data = sys.stdin.buffer.read().split()
    n, m = int(data[0]), int(data[1])
    tokens = map(int, data[2:2 + 2 * m])
    # zip the same iterator with itself to consume tokens in (a, b) pairs.
    edges = [(a - 1, b - 1) for a, b in zip(tokens, tokens)]
    print(min_edges_to_strongly_connect(n, edges))


if __name__ == "__main__":
    main()