# Reads "m n" and a line of marked positions from stdin, then prints the sum,
# over every gap between neighbouring marked positions (with sentinels at 0 and
# m + 1), of 1^2 + 2^2 + ... + L^2 where L is the gap length, modulo MOD.

MOD = 998244353


def _sum_of_squares(length):
    """Return length * (length + 1) * (2 * length + 1) // 6.

    For ``length >= 0`` this is the closed form of 1**2 + 2**2 + ... + length**2.
    """
    return length * (length + 1) * (2 * length + 1) // 6


def solve(m, positions):
    """Return the total of per-gap square sums, reduced modulo ``MOD``.

    Args:
        m: Upper bound of the range; a sentinel is placed at ``m + 1``.
        positions: Marked positions. NOTE(review): assumed to be given in
            ascending order inside ``1..m`` -- the gaps are taken between
            neighbours in the order supplied; confirm against the input spec.

    Returns:
        The sum of ``_sum_of_squares(gap)`` over all adjacent pairs of
        ``[0, *positions, m + 1]``, modulo ``MOD``.
    """
    bounds = [0, *positions, m + 1]
    # Number of unmarked cells strictly between two neighbouring bounds.
    gaps = (hi - lo - 1 for lo, hi in zip(bounds, bounds[1:]))
    # Python ints do not overflow, so reducing once at the end gives the same
    # residue as reducing after every addition.
    return sum(_sum_of_squares(gap) for gap in gaps) % MOD


def main():
    """Read the two input lines from stdin and print the answer."""
    # The second header value (the count of positions) is not needed: the
    # whole second line is consumed regardless of how many values it holds.
    m, _count = map(int, input().split())
    positions = list(map(int, input().split()))
    print(solve(m, positions))


if __name__ == "__main__":
    main()