mod = 998244353


def f(k):
    """Return k*(k+1)*(2k+1)/6 reduced modulo `mod`.

    For k >= 0 this is 1^2 + 2^2 + ... + k^2.  The product
    k*(k+1)*(2k+1) is always divisible by 6, so the exact integer
    division is performed before reducing.
    """
    return (k * (k + 1) * (2 * k + 1) // 6) % mod


def solve(m, n, x):
    """Sum f(gap) over the gaps delimited by positions in `x`, modulo `mod`.

    Parameters:
        m: right end of the range; the final gap is ``m - last position``.
        n: number of positions read from `x`.
        x: the positions, in input order (presumably sorted ascending —
           TODO confirm against the problem statement).

    Returns:
        (f(x[0]-1) + f(x[1]-x[0]-1) + ... + f(m - x[-1])) % mod, with an
        implicit leading position 0.
    """
    xs = [0] + list(x)
    ans = 0
    for i in range(1, n + 1):
        ans += f(xs[i] - xs[i - 1] - 1)
        # Original code had the no-op `ans%mod`; actually reduce here.
        ans %= mod
    ans += f(m - xs[-1])
    return ans % mod


def main():
    """Read `m n` followed by n positions from stdin and print the answer."""
    import sys

    # Token-based reading tolerates a missing/empty second line when n == 0.
    data = sys.stdin.read().split()
    m, n = int(data[0]), int(data[1])
    x = [int(t) for t in data[2:2 + n]]
    print(solve(m, n, x))


if __name__ == "__main__":
    main()