import sys
from bisect import bisect_left


def compress(lst):
    """Coordinate-compress ``lst``.

    Returns a tuple ``(B, idx_to_val, val_to_idx)`` where

    * ``B[i]`` is the rank of ``lst[i]`` among the distinct values of ``lst``;
    * ``idx_to_val`` is the sorted list of distinct values, so that
      ``idx_to_val[B[i]] == lst[i]``;
    * ``val_to_idx`` maps each distinct original value to its rank.
    """
    idx_to_val = sorted(set(lst))
    B = [bisect_left(idx_to_val, x) for x in lst]
    val_to_idx = {v: i for i, v in enumerate(idx_to_val)}
    return B, idx_to_val, val_to_idx


class DirectedGraph:
    """Directed graph on vertices 0..N-1 with Kosaraju SCC decomposition."""

    def __init__(self, N):
        self.N = N
        self.G = [[] for _ in range(N)]   # forward adjacency lists
        self.rG = [[] for _ in range(N)]  # reversed adjacency lists
        self.order = []                   # DFS post-order, filled by dfs()
        self.used1 = [0] * N              # visited flags for the forward pass
        self.used2 = [0] * N              # visited flags for the reverse pass
        self.group = [-1] * N             # SCC label per vertex (-1 until scc())
        self.label = 0                    # number of SCCs found so far
        self.seen = [0] * N               # 1 once a vertex has been appended to order
        self.Edge = set()                 # (u, v) pairs already added

    def add_edge(self, u, v):
        """Add the edge u -> v; a duplicate (u, v) edge is ignored."""
        if (u, v) in self.Edge:
            return
        self.Edge.add((u, v))
        self.G[u].append(v)
        self.rG[v].append(u)

    def dfs(self, s):
        """Iterative DFS over G from s, appending vertices to self.order in post-order.

        A non-negative stack entry u means "enter u"; the complement ~u
        (always negative) means "all of u's descendants are finished, record u".
        """
        stack = [~s, s]
        while stack:
            u = stack.pop()
            if u >= 0:
                if self.used1[u]:
                    continue
                self.used1[u] = 1
                for v in self.G[u]:
                    if not self.used1[v]:
                        stack.append(~v)
                        stack.append(v)
            else:
                u = ~u
                if not self.seen[u]:
                    self.seen[u] = 1
                    self.order.append(u)

    def rdfs(self, s, num):
        """Give label num to s and every not-yet-labelled vertex reachable from s in rG."""
        # Mark on push (not on pop) so every vertex enters the stack at most once;
        # marking on pop lets a vertex be pushed and re-expanded many times.
        self.used2[s] = 1
        self.group[s] = num
        stack = [s]
        while stack:
            u = stack.pop()
            for v in self.rG[u]:
                if not self.used2[v]:
                    self.used2[v] = 1
                    self.group[v] = num
                    stack.append(v)

    def scc(self):
        """Compute strongly connected components with Kosaraju's algorithm.

        Returns (number_of_components, group), where group[v] is v's component.
        Components are numbered in topological order of the condensation: if an
        edge goes from component a to a different component b, then a < b.
        """
        for i in range(self.N):
            if not self.used1[i]:
                self.dfs(i)
        for s in reversed(self.order):
            if not self.used2[s]:
                self.rdfs(s, self.label)
                self.label += 1
        return self.label, self.group

    def construct(self):
        """Build the condensation DAG; must be called after scc().

        Returns (nG, mem): nG[c] is the set of components reachable from
        component c by one edge, and mem[c] lists c's vertices in ascending order.
        """
        nG = [set() for _ in range(self.label)]
        mem = [[] for _ in range(self.label)]
        for s in range(self.N):
            now = self.group[s]
            mem[now].append(s)
            for u in self.G[s]:
                if self.group[u] != now:
                    nG[now].add(self.group[u])
        return nG, mem


def solve(N, B, C):
    """Return N*(N+1)//2 plus, for every distinct value v occurring in the pairs,
    (w - v), where w is the largest value that reaches v along edges C[i] -> B[i]
    (v itself included).

    NOTE(review): this presumably models "value B[i] may be replaced by C[i]",
    with replacements chainable — confirm against the problem statement.
    With no pairs (empty B and C) the result is just N*(N+1)//2.
    """
    M = len(B)
    D, itov, _ = compress(B + C)
    b, c = D[:M], D[M:]
    # Number of distinct values; unlike max(D) + 1 this is also valid when M == 0.
    K = len(itov)
    G = DirectedGraph(K)
    for bi, ci in zip(b, c):
        G.add_edge(ci, bi)
    L, _ = G.scc()
    nG, mem = G.construct()

    ans = N * (N + 1) // 2
    # dp[comp] = largest compressed value that can reach comp; compression is
    # order-preserving, so the largest index is also the largest value.
    dp = [max(members) for members in mem]
    # Labels are topologically ordered, so dp[i] is final once i is reached.
    for i in range(L):
        for u in nG[i]:
            if dp[i] > dp[u]:
                dp[u] = dp[i]
        for m in mem[i]:
            ans += max(0, itov[dp[i]] - itov[m])
    return ans


def main():
    """Read "N M" followed by M pairs "B C" from stdin and print the answer."""
    data = sys.stdin.read().split()
    N, M = int(data[0]), int(data[1])
    B = [int(x) for x in data[2:2 + 2 * M:2]]
    C = [int(x) for x in data[3:3 + 2 * M:2]]
    print(solve(N, B, C))


if __name__ == "__main__":
    main()