"""Accumulate sums of squares over the gaps between given values.

Input: a line with ``m n`` followed by ``n`` integers.
Output: the sum, modulo 998244353, of S(k) = 1^2 + ... + k^2, taken for
every pair of consecutive values in ``0, x_1, ..., x_n, m + 1``, where k is
the count of integers strictly between the two values of the pair.

NOTE(review): the values are used in the order given; the "gap" reading
assumes they are sorted ascending and lie in 1..m -- TODO confirm against
the problem statement.
"""

import sys

MOD = 998244353


def add(l, r):
    """Return k * (k + 1) * (2k + 1) // 6 where k = r - l.

    For k >= 0 this is 1^2 + 2^2 + ... + k^2. With l = previous value + 1 and
    r = next value, k is the number of integers strictly between the two.
    """
    k = r - l
    return k * (k + 1) * (2 * k + 1) // 6


def solve(m, xs):
    """Return the sum of add(prev + 1, v) over consecutive pairs, mod MOD.

    The pairs are taken from the sequence 0, *xs, m + 1. The sentinel m + 1
    closes the final gap, and the initial prev = 0 opens the first one.
    """
    total = 0
    prev = 0
    for v in [*xs, m + 1]:
        # Reduce every step so the running total stays small.
        total = (total + add(prev + 1, v)) % MOD
        prev = v
    return total


def main():
    """Read ``m n`` and n integers from stdin; print solve(m, values)."""
    # Reading all tokens at once tolerates n == 0 with an absent or empty
    # second line, where a second input() call would raise EOFError.
    data = sys.stdin.read().split()
    m, n = int(data[0]), int(data[1])
    xs = [int(tok) for tok in data[2:2 + n]]
    print(solve(m, xs))


if __name__ == "__main__":
    main()