import sys

MOD = 998244353


def solve(m, positions, mod=MOD):
    """Return the sum of 0^2 + 1^2 + ... + (k-1)^2 over every gap k, modulo *mod*.

    Boundaries are ``0``, each value of *positions* (in the given order), and
    ``m + 1``. For every pair of consecutive boundaries ``(left, right)`` the
    gap is ``k = right - left``. The gap contributes
    ``k*(k-1)*(2k-1)/6``, which is the closed form of ``sum(j*j for j in range(k))``.

    NOTE(review): presumably *positions* are sorted ascending within
    ``[1, m]`` per the problem statement. This is not checked here. Unsorted
    input still evaluates the same polynomial, exactly as the original did.

    Args:
        m: Upper bound; the right sentinel boundary is ``m + 1``.
        positions: Iterable of integer boundary positions.
        mod: Modulus applied after each gap's contribution.

    Returns:
        The accumulated total reduced into ``[0, mod)``.
    """
    bounds = [0, *positions, m + 1]
    total = 0
    for left, right in zip(bounds, bounds[1:]):
        k = right - left
        # k*(k-1)*(2k-1) is divisible by 6 for every integer k, so // is exact.
        total = (total + k * (k - 1) * (2 * k - 1) // 6) % mod
    return total


def main():
    """Read ``M N`` followed by N positions from stdin and print the answer."""
    # Token-based parsing: the N positions may span any number of lines,
    # and N == 0 with no second line is handled naturally.
    data = sys.stdin.read().split()
    m, n = int(data[0]), int(data[1])
    positions = [int(tok) for tok in data[2:2 + n]]
    print(solve(m, positions))


if __name__ == "__main__":
    main()