n, m = map(int, input().split()) max_replace = {} for _ in range(m): b, c = map(int, input().split()) if c > b: if b in max_replace: if c > max_replace[b]: max_replace[b] = c else: max_replace[b] = c visited = set() final_values = {} total_delta = 0 for b in max_replace.keys(): if b not in visited: path = [] current = b while current in max_replace and current not in visited: visited.add(current) path.append(current) current = max_replace[current] # Determine final value final = final_values.get(current, current) # Update final_values for all nodes in path for node in path: final_values[node] = final total_delta += final - node initial_sum = n * (n + 1) // 2 print(initial_sum + total_delta)