import java.io.BufferedReader;
import java.io.InputStreamReader;

/**
 * Reads {@code m n} on the first line followed by {@code n} integer positions, then prints
 * the sum, over every gap between consecutive positions, of {@code d(d+1)(2d+1)/6}
 * (that is, {@code 1^2 + 2^2 + ... + d^2}) modulo 998244353.
 *
 * <p>The positions are framed by two sentinels, {@code 0} on the left and {@code m + 1} on
 * the right; a gap {@code d} is the count of integers strictly between two neighbours.
 */
public class Main {

    /** Prime modulus under which the answer is reported. */
    private static final int MOD = 998244353;

    /**
     * Program entry point: parses standard input and prints the answer.
     *
     * @param args unused
     * @throws Exception if standard input cannot be read or is malformed
     */
    public static void main(String[] args) throws Exception {
        long m;
        long[] positions;
        // try-with-resources guarantees the reader is closed even when parsing fails.
        try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in))) {
            String[] header = br.readLine().trim().split("\\s+");
            m = Long.parseLong(header[0]);
            int n = Integer.parseInt(header[1]);
            positions = readPositions(br, n);
        }
        System.out.println(solve(m, positions));
    }

    /**
     * Reads exactly {@code n} whitespace-separated integers, possibly spread over several
     * lines. Blank lines are skipped, so {@code n == 0} with no second line is accepted.
     */
    private static long[] readPositions(BufferedReader br, int n) throws java.io.IOException {
        long[] positions = new long[n];
        int filled = 0;
        String line;
        while (filled < n && (line = br.readLine()) != null) {
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                continue;
            }
            for (String token : trimmed.split("\\s+")) {
                if (filled == n) {
                    break;
                }
                positions[filled++] = Long.parseLong(token);
            }
        }
        if (filled < n) {
            throw new IllegalArgumentException(
                    "expected " + n + " positions but found only " + filled);
        }
        return positions;
    }

    /**
     * Computes the answer for upper bound {@code m} and the given positions.
     *
     * @param m         upper bound; the right sentinel is placed at {@code m + 1}
     * @param positions the positions in input order (the gaps are taken between neighbours
     *                  in this order, so the caller is presumably expected to pass them
     *                  ascending)
     * @return the sum of {@code d(d+1)(2d+1)/6} over all gaps {@code d}, in {@code [0, MOD)}
     */
    static long solve(long m, long[] positions) {
        long inverseOfSix = modinv(6, MOD);
        long total = 0;
        long previous = 0; // left sentinel
        for (int i = 0; i <= positions.length; i++) {
            long current = (i < positions.length) ? positions[i] : m + 1; // right sentinel last
            long gap = current - previous - 1;
            total = (total + sumOfSquaresMod(gap, inverseOfSix)) % MOD;
            previous = current;
        }
        return total;
    }

    /**
     * Evaluates {@code d(d+1)(2d+1)/6 mod MOD}.
     *
     * <p>{@code d} is reduced modulo {@code MOD} first: the expression is a polynomial in
     * {@code d}, so its residue depends only on {@code d mod MOD}, and working with the
     * reduced value keeps every intermediate product below {@code 2^63}. Multiplying the raw
     * {@code d} would overflow {@code long} as soon as {@code d} exceeds roughly 3e9.
     */
    private static long sumOfSquaresMod(long d, long inverseOfSix) {
        long r = Math.floorMod(d, (long) MOD);
        long product = r * (r + 1) % MOD;
        product = product * ((2 * r + 1) % MOD) % MOD;
        return product * inverseOfSix % MOD;
    }

    /**
     * Returns the modular inverse of {@code a} modulo {@code m} using the extended Euclidean
     * algorithm. The result is normalised into {@code [0, m)}; it is a true inverse only
     * when {@code a} and {@code m} are coprime.
     *
     * @param a value to invert
     * @param m modulus
     * @return {@code u} in {@code [0, m)} such that {@code a * u % m == 1} for coprime inputs
     */
    static long modinv(long a, int m) {
        long b = m;
        long u = 1;
        long v = 0;
        while (b > 0) {
            long t = a / b;
            // (a, b) <- (b, a - t * b)
            long nextB = a - t * b;
            a = b;
            b = nextB;
            // (u, v) <- (v, u - t * v), tracking the coefficient of the original a
            long nextV = u - t * v;
            u = v;
            v = nextV;
        }
        u %= m;
        if (u < 0) {
            u += m;
        }
        return u;
    }
}