m,n=map(int,input().split()) x=list(map(int,input().split()))+[m+1] M=998244353 i6=pow(6,M-2,M) f=lambda n:i6*n*(n+1)*(2*n+1)%M now=0 g=0 for v in x: g+=f(v-1-now) g%=M now=v print(g)