import sys

MOD = 998244353


def inv(x):
    """Return the modular inverse of x modulo the prime MOD (via Fermat's little theorem)."""
    return pow(x, MOD - 2, MOD)


# Inverse of 6, computed once rather than on every sm() call.
INV6 = inv(6)


def sm(x):
    """Return 1^2 + 2^2 + ... + x^2 = x(x+1)(2x+1)/6, reduced modulo MOD."""
    # The expression is a polynomial in x, so reducing x first gives the same
    # residue while keeping every intermediate product below MOD**2.
    x %= MOD
    return x * (x + 1) % MOD * (2 * x + 1) % MOD * INV6 % MOD


def gap_lengths(m, xs):
    """Return the lengths of the non-empty runs of positions strictly between
    consecutive values of [0, *xs, m + 1].

    Non-positive gaps (e.g. from repeated values) are skipped.
    NOTE(review): assumes xs is sorted ascending, presumably guaranteed by the
    input format — TODO confirm.
    """
    gaps = []
    prev = 0
    # The sentinel m + 1 closes the trailing run after the last listed value.
    for x in [*xs, m + 1]:
        gap = x - prev - 1
        if gap > 0:
            gaps.append(gap)
        prev = x
    return gaps


def solve(m, xs):
    """Return the sum of sm(g) over every gap length g from gap_lengths, modulo MOD."""
    return sum(sm(g) for g in gap_lengths(m, xs)) % MOD


def main():
    """Read "M N" followed by N integers X from stdin and print solve(M, X)."""
    # Token-based reading tolerates a missing/empty second line when N == 0.
    data = sys.stdin.buffer.read().split()
    m, n = int(data[0]), int(data[1])
    xs = [int(tok) for tok in data[2:2 + n]]
    print(solve(m, xs))


if __name__ == "__main__":
    main()