import sys
from sys import stdin
from collections import defaultdict


def sum_of_max_reachable(n, edge_pairs):
    """Return the sum, over nodes 1..n, of the largest node reachable from each.

    A node always reaches itself, so a node that appears in no edge
    contributes its own label.

    Args:
        n: Number of nodes; nodes are assumed to be labelled 1..n.
        edge_pairs: Iterable of ``(src, dst)`` pairs, each a directed edge
            ``src -> dst``.

    Returns:
        ``n * (n + 1) // 2`` plus, for every node that appears in an edge,
        the difference between its largest reachable node and itself.
    """
    # Work on the reversed graph: walking backwards from a node `top` visits
    # exactly the nodes that can reach `top` in the original graph.
    predecessors = defaultdict(list)
    nodes = set()
    for src, dst in edge_pairs:
        nodes.add(src)
        nodes.add(dst)
        predecessors[dst].append(src)

    # Seeds are taken in descending order, so the first seed that reaches a
    # node is the largest node it can reach.  This stays correct when the
    # graph has cycles, and the explicit stack avoids Python's recursion
    # limit on long paths.
    best = {}
    for top in sorted(nodes, reverse=True):
        if top in best:
            continue
        best[top] = top
        stack = [top]
        while stack:
            node = stack.pop()
            # .get() so that looking up a node without predecessors does not
            # insert an empty list into the defaultdict.
            for pred in predecessors.get(node, ()):
                if pred not in best:
                    best[pred] = top
                    stack.append(pred)

    gain = sum(top - node for node, top in best.items())
    return n * (n + 1) // 2 + gain


def main():
    """Read ``N M`` and then M directed edges ``B C`` from stdin; print the total."""
    n, m = map(int, stdin.readline().split())
    edge_pairs = [tuple(map(int, stdin.readline().split())) for _ in range(m)]
    print(sum_of_max_reachable(n, edge_pairs))


if __name__ == "__main__":
    main()