"""Sum, over the gaps between marked positions, of 1^2 + 2^2 + ... + h^2.

Input: a line "m n", then a line with the n marked positions.  Sentinels 0
and m + 1 bracket the positions; for every pair of consecutive bounds, h is
the number of integers strictly between them.  The printed answer is the
sum of g(h) over all such pairs, modulo 998244353.
"""
import sys


def input():
    """Return the next stdin line with its trailing newline (if any) removed."""
    # rstrip rather than [:-1]: a final line that lacks "\n" must not lose its
    # last character (which would corrupt the last number on that line).
    return sys.stdin.readline().rstrip("\n")


def ni():
    """Read one line and parse it as a single int."""
    return int(input())


def na():
    """Read one line and parse it as a list of whitespace-separated ints."""
    return list(map(int, input().split()))


def yes():
    print("yes")


def Yes():
    print("Yes")


def YES():
    print("YES")


def no():
    print("no")


def No():
    print("No")


def NO():
    print("NO")


#######################################################################

mod = 998244353


def g(x):
    """Return (1^2 + 2^2 + ... + x^2) mod ``mod``.

    Uses the closed form x(x+1)(2x+1)/6.  Python ints are unbounded, so the
    product and the exact division happen before the reduction.  g(0) and
    g(-1) are both 0.
    """
    return x * (x + 1) * (2 * x + 1) // 6 % mod


def f(s, h):
    """Return (s^2 + (s+1)^2 + ... + (s+h-1)^2) mod ``mod``.

    Computed as a difference of two prefix sums.  Each g() value is already
    reduced mod ``mod``, so the raw difference can be negative; the outer
    ``% mod`` brings the result back into [0, mod).
    """
    return (g(s + h - 1) - g(s - 1)) % mod


def solve(m, positions):
    """Return the sum of g(h) over all gaps, modulo ``mod``.

    The bounds are ``[0] + positions + [m + 1]``.  For each pair of
    consecutive bounds, h = right - left - 1 is the number of integers
    strictly between them.

    NOTE(review): assumes ``positions`` is sorted ascending with values in
    1..m.  That matches how the sentinels are used; nothing here sorts, so
    confirm against the problem statement.
    """
    bounds = [0] + list(positions) + [m + 1]
    total = 0
    for left, right in zip(bounds, bounds[1:]):
        total += g(right - left - 1)
    return total % mod


def main():
    """Read "m n" and the n positions from stdin, then print the answer."""
    # n is implied by the length of the second line, which is expected to
    # hold exactly n values.
    m, _ = na()
    positions = na()
    print(solve(m, positions))


if __name__ == "__main__":
    main()