def _best_reachable(adjacency):
    """Return the largest value reachable from each source value.

    ``adjacency`` maps a value ``b`` to the non-empty set of values ``c`` that
    ``b`` may be replaced with; every stored edge satisfies ``c > b``. Because
    each edge strictly increases the value, the graph is acyclic, and visiting
    sources in descending order guarantees every successor's answer is final
    before it is read.

    Values with no outgoing edges are not stored; they reach only themselves.
    """
    best = {}
    for b in sorted(adjacency, reverse=True):
        # Every successor c exceeds b, so this max is always > b. Consider all
        # edges: a smaller c may chain on to something larger than the biggest
        # direct target.
        best[b] = max(best.get(c, c) for c in adjacency[b])
    return best


def _max_total(n, ops):
    """Return the maximum sum of 1..n after applying replacements.

    Each element x of the starting multiset 1..n is treated independently and
    is raised to the largest value reachable from x along any chain of
    ``(b, c)`` replacements with ``c > b``.

    NOTE(review): this model is inferred from the code (initial sum
    n*(n+1)//2 plus a transitive "best reachable" pass); the problem statement
    is not visible here, so confirm it.

    Args:
        n: size of the starting sequence 1..n.
        ops: iterable of ``(b, c)`` integer pairs.

    Returns:
        The resulting total as an int.
    """
    adjacency = {}
    for b, c in ops:
        if c > b:  # a non-increasing replacement can never help
            adjacency.setdefault(b, set()).add(c)

    best = _best_reachable(adjacency)

    total = n * (n + 1) // 2
    # Only values actually present in 1..n contribute a gain; sources outside
    # that range matter only as intermediate hops, already folded into `best`.
    total += sum(best[x] - x for x in best if 1 <= x <= n)
    return total


def main():
    """Read ``N M`` and then ``M`` pairs ``B_i C_i`` from stdin; print the best total."""
    import sys

    numbers = map(int, sys.stdin.read().split())
    n = next(numbers)
    m = next(numbers)
    # List comprehension (not a generator) so a truncated input surfaces as an
    # error instead of being silently shortened.
    ops = [(next(numbers), next(numbers)) for _ in range(m)]
    print(_max_total(n, ops))


if __name__ == "__main__":
    main()