from collections import defaultdict


def solve(n, edges):
    """Return n*(n+1)//2 plus the total "gain" found by a reverse-edge search.

    Each pair ``(b, c)`` in *edges* is stored as ``b`` being a predecessor of
    ``c``.  Vertices that have at least one predecessor are taken as search
    roots in descending label order; every not-yet-visited vertex reached by
    walking predecessor links from a root ``r`` is assigned ``max(vertex, r)``.
    The result adds ``assigned - vertex`` for every visited vertex to
    ``n * (n + 1) // 2``.

    NOTE(review): if each pair is a directed edge ``b -> c`` and labels are
    exactly ``1..n``, this equals the sum over all vertices of the largest
    label reachable from that vertex (itself included) -- confirm against the
    problem statement.

    Args:
        n: The first integer of the input header; only used in the
            ``n * (n + 1) // 2`` base term.
        edges: Iterable of ``(b, c)`` integer pairs.

    Returns:
        The integer total described above.
    """
    predecessors = defaultdict(list)
    for b, c in edges:
        predecessors[c].append(b)

    # best[v] is the value assigned to v; membership doubles as "visited".
    best = {}
    # Descending order: the first root to reach a vertex is the largest one
    # that can, so a vertex is never revisited or overwritten.
    for root in sorted(predecessors, reverse=True):
        if root in best:
            continue
        best[root] = root
        stack = [root]
        while stack:
            node = stack.pop()
            # .get() avoids inserting empty lists for vertices that have no
            # predecessors.
            for prev in predecessors.get(node, ()):
                if prev not in best:
                    # prev may exceed root when prev itself has no
                    # predecessors (it was never a root), hence the max().
                    best[prev] = max(prev, root)
                    stack.append(prev)

    return n * (n + 1) // 2 + sum(value - node for node, value in best.items())


def main():
    """Read ``N M`` and then M lines ``b c`` from stdin; print the answer."""
    n, m = map(int, input().split())
    edges = []
    for _ in range(m):
        b, c = map(int, input().split())
        edges.append((b, c))
    print(solve(n, edges))


if __name__ == "__main__":
    main()