import sys

mod = 998244353


def f(k):
    """Return 1^2 + 2^2 + ... + k^2 via the closed form k(k+1)(2k+1)/6.

    Returns 0 for k == 0. Exact integer arithmetic, so no overflow concerns.
    """
    return k * (k + 1) * (2 * k + 1) // 6


def solve(m, positions):
    """Sum f(gap) over the gaps between consecutive marked positions, mod `mod`.

    The range is bounded by sentinels 0 and m + 1. Each gap is the number of
    cells strictly between two neighbouring bounds.

    Args:
        m: length of the range; the right sentinel is m + 1.
        positions: marked positions, presumably strictly increasing within
            [1, m] (NOTE(review): sortedness is assumed, as in the original;
            confirm against the problem statement).

    Returns:
        The total modulo 998244353.
    """
    bounds = [0, *positions, m + 1]
    # zip(bounds, bounds[1:]) walks consecutive (left, right) pairs.
    total = sum(f(right - left - 1) for left, right in zip(bounds, bounds[1:]))
    # Python ints are arbitrary precision, so a single final reduction is
    # equivalent to reducing after every term.
    return total % mod


def main():
    """Read `m n` followed by n positions from stdin and print the answer."""
    # Read every token at once: when n == 0 the positions line may be empty
    # or missing entirely, and a second input() call would raise EOFError.
    tokens = sys.stdin.read().split()
    m, n = int(tokens[0]), int(tokens[1])
    positions = list(map(int, tokens[2:2 + n]))
    print(solve(m, positions))


if __name__ == "__main__":
    main()