# import pypyjit
# pypyjit.set_param('max_unroll_recursion=-1')
from collections import defaultdict as dd

# Short aliases used by the input helpers below.
S = input
R = range
P = print


def I():
    """Read one line from stdin and return it as an int."""
    return int(S())


def M():
    """Read one line from stdin and return a lazy map of its ints."""
    return map(int, S().split())


def L():
    """Read one line from stdin and return its ints as a list."""
    return list(M())


def O():
    """Read ALL remaining stdin (fd 0) and return every token as an int."""
    return list(map(int, open(0).read().split()))


def yn(b):
    """Print "Yes" if b is truthy, otherwise "No"."""
    print("Yes" if b else "No")


biga = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
smaa = "abcdefghijklmnopqrstuvwxyz"
ctoi = lambda c: ord(c) - ord('a')
itoc = lambda i: chr(ord('a') + i)
inf = 10 ** 18
mod = 998244353


def acc(a):
    """Return the prefix sums of a, starting with a leading 0."""
    b = [0]
    for i in a:
        b.append(b[-1] + i)
    return b


def sum_of_squares(d):
    """Return 1**2 + 2**2 + ... + d**2 via the closed form d(d+1)(2d+1)/6.

    For d == 0 the result is 0. Exact integer arithmetic; the product is
    always divisible by 6, so floor division is exact.
    """
    return d * (d + 1) * (2 * d + 1) // 6


def solve(m, x, modulus=mod):
    """Sum sum_of_squares(gap) over every gap, reduced modulo ``modulus``.

    ``x`` is bracketed by sentinels 0 and m + 1. A "gap" is the count of
    integers strictly between two consecutive values of that padded list.

    Args:
        m: upper bound; m + 1 is used as the right sentinel.
        x: the given positions. Presumably strictly increasing values in
           1..m -- NOTE(review): not checked here; confirm against the
           problem statement.
        modulus: modulus applied to the final total (default 998244353).

    Returns:
        The total modulo ``modulus``.
    """
    bounds = [0, *x, m + 1]
    total = 0
    for left, right in zip(bounds, bounds[1:]):
        total += sum_of_squares(right - left - 1)
    # Python ints are unbounded, so reducing once at the end is exact.
    return total % modulus


if __name__ == "__main__":
    # Read every token at once: when n == 0 the second input line may be
    # missing entirely, which made the old per-line L() call raise EOFError.
    data = O()
    m, n = data[0], data[1]
    x = data[2:2 + n]
    P(solve(m, x))