"""Sum, over every maximal run of free cells, of 1^2 + 2^2 + ... + len^2.

Input (stdin):
    line 1: M N   -- cells are numbered 1..M; N cells are blocked
    line 2: X_1 ... X_N -- the blocked cell positions

Output: the total modulo 998244353.
"""
import sys

MOD = 998244353


def sum_of_squares_mod(n, mod=MOD):
    """Return (1^2 + 2^2 + ... + n^2) modulo *mod*.

    Uses the closed form n(n+1)(2n+1)/6. Python ints are unbounded and the
    product is always divisible by 6, so it is divided exactly before
    reducing. This avoids needing a modular inverse of 6 and therefore works
    for any positive modulus, not only primes > 3.
    For n == 0 and n == -1 the result is 0 (an empty run).
    """
    return n * (n + 1) * (2 * n + 1) // 6 % mod


def solve(m, blocked, mod=MOD):
    """Return the sum of sum_of_squares(run length) over all free runs.

    Args:
        m: number of cells, numbered 1..m.
        blocked: iterable of blocked cell positions (any order; duplicates
            contribute an empty run and are therefore harmless).
        mod: modulus applied to the result.

    Returns:
        The total modulo *mod*.
    """
    total = 0
    start = 1  # first cell of the free run currently being measured
    # Gaps are only meaningful between consecutive positions, so sort
    # defensively; this is a no-op when the input is already ascending.
    for x in sorted(blocked):
        total += sum_of_squares_mod(x - start, mod)
        start = x + 1
    # Trailing run after the last blocked cell (possibly empty).
    total += sum_of_squares_mod(m + 1 - start, mod)
    return total % mod


def main():
    """Read the problem from stdin and print the answer."""
    readline = sys.stdin.readline
    m, _n = map(int, readline().split())
    # When N == 0 the second line may be absent; readline() then returns ""
    # and split() yields an empty list.
    blocked = list(map(int, readline().split()))
    print(solve(m, blocked))


if __name__ == "__main__":
    main()