"""Sum, modulo 998244353, of 1^2 + ... + p^2 over every gap p between
consecutive values of ``[0] + A + [M + 1]``.

Input (stdin): ``M N`` on the first line, then the N integers of A.
"""
import sys

MOD = 998244353
mod = MOD  # original module-level name, kept for backward compatibility


def f(x):
    """Return 1^2 + 2^2 + ... + x^2, i.e. x(x+1)(2x+1)/6, as an exact int.

    x(x+1)(2x+1) is always divisible by 6, so ``//`` is exact even for
    negative x (e.g. f(-1) == 0, f(-2) == -1).
    """
    return x * (x + 1) * (2 * x + 1) // 6


def solve(m, a, mod=MOD):
    """Return the sum of f(gap) over consecutive pairs of [0, *a, m + 1], mod ``mod``.

    Each gap is ``right - left - 1``, the count of integers strictly between
    two consecutive values.
    NOTE(review): this presumably expects ``a`` sorted ascending with values in
    [1, m], so the gaps are the runs of 1..m not in ``a``. The order is NOT
    checked here; confirm against the problem statement.

    Args:
        m: Upper bound; ``m + 1`` is appended as a sentinel.
        a: The input values.
        mod: Modulus applied to the running total (default 998244353).

    Returns:
        The total reduced into [0, mod).
    """
    bounds = [0, *a, m + 1]  # sentinels on both ends, as in the original script
    total = 0
    for left, right in zip(bounds, bounds[1:]):
        # Reduce every step, as the original did, so the total stays small.
        total = (total + f(right - left - 1)) % mod
    return total


def main():
    """Read ``M N`` and the N values of A from stdin, and print the answer."""
    # Read all tokens at once so that N == 0 with a missing second line
    # does not raise EOFError.
    data = sys.stdin.read().split()
    m, n = int(data[0]), int(data[1])
    a = [int(tok) for tok in data[2:2 + n]]
    print(solve(m, a))


if __name__ == "__main__":
    main()