def main():
    """Read N and M (B, C) pairs from stdin and print the adjusted total.

    Input is whitespace-separated integers: N, M, then M pairs ``B C``.
    Only pairs with ``B < C`` are kept, and for each B the largest such C
    is remembered. Each remembered start x is then followed along these
    "largest C" links until it reaches a value with no outgoing link.
    The printed result is ``1 + 2 + ... + N`` plus the sum of
    ``(end - x)`` over every remembered start x.
    """
    import sys

    values = map(int, sys.stdin.read().split())
    n = next(values)
    m = next(values)

    # best_next[b] = largest c seen so far with b < c.
    best_next = {}
    for _ in range(m):
        b = next(values)
        c = next(values)
        if b < c:
            best_next[b] = max(c, best_next.get(b, c))

    # Every link strictly increases (b < c), so chains cannot cycle.
    # Visiting starts from largest to smallest guarantees that a
    # successor's final endpoint is already resolved when we need it.
    # This replaces a recursive walk that overflowed Python's recursion
    # limit on long chains (e.g. 1 -> 2 -> ... -> N).
    final = {}
    for start in sorted(best_next, reverse=True):
        nxt = best_next[start]
        final[start] = final.get(nxt, nxt)

    total = n * (n + 1) // 2 + sum(end - start for start, end in final.items())
    print(total)


if __name__ == "__main__":
    main()