import sys

MOD = 998244353


def solve(m, n, xs):
    """Return the sum, over every gap of unmarked cells, of 1^2 + 2^2 + ... + d^2 (mod MOD).

    Cells are numbered 1..m. The first ``n`` values of ``xs`` are the marked
    cells. Sentinels 0 and m + 1 bound the row, so the gaps at both ends are
    included. A gap of ``d`` free cells contributes d(d+1)(2d+1)/6, which is
    the sum of squares 1..d. The division by 6 is done via its modular inverse.

    NOTE(review): assumes ``xs`` is sorted ascending with distinct values in
    1..m, as the original script did. With unsorted input, gap widths would
    be negative and the result meaningless. Confirm against the problem statement.

    Args:
        m: number of cells in the row.
        n: number of marked positions to use from ``xs``.
        xs: marked positions.

    Returns:
        The total modulo 998244353.
    """
    inv6 = pow(6, -1, MOD)  # modular inverse of 6 (MOD is prime); needs Python 3.8+
    bounds = [0, *xs[:n], m + 1]
    total = 0
    for left, right in zip(bounds, bounds[1:]):
        d = right - left - 1  # number of free cells strictly between the two bounds
        total += d * (d + 1) % MOD * (2 * d + 1) % MOD * inv6 % MOD
    return total % MOD


def main():
    """Read ``m n`` followed by ``n`` positions from stdin and print the answer.

    All tokens are read at once, so an absent positions line (n == 0) or
    positions split across several lines do not raise EOFError.
    """
    data = sys.stdin.buffer.read().split()
    m, n = int(data[0]), int(data[1])
    xs = [int(tok) for tok in data[2:2 + n]]
    print(solve(m, n, xs))


if __name__ == "__main__":
    main()