m,n=map(int,input().split()) x=map(int,input().split()) mod=998244353 ans=0 before=0 for xi in x: if xi-before-1==0: before=xi continue n=xi-before-1 # print(before, xi, n) ans+=n*(n+1)*(2*n+1)//6 ans%=mod before=xi if m-before!=0: n=m-before # print(before, m, n) ans+=n*(n+1)*(2*n+1)//6 ans%=mod print(ans)