MOD = 998244353


def count_ways(n, m, a, mod=MOD):
    """Count sequences of ``m`` segment flips that turn ``a`` into all zeros.

    Pad ``a`` with a 0 on each side. Flipping the segment ``a[l..r]`` toggles
    exactly two of the ``n + 1`` boundaries between neighbouring cells
    (boundaries ``l - 1`` and ``r``), and every unordered pair of distinct
    boundaries corresponds to exactly one segment. The array is all zeros iff
    no boundary separates differing values. So only the number ``k`` of
    "active" boundaries matters, and each step moves ``k`` by -2, 0 or +2.

    NOTE(review): assumes every element of ``a`` is 0 or 1 and that
    ``len(a) == n``. Other inputs are not validated; they can produce a
    meaningless count or an IndexError.

    Args:
        n: Length of the array.
        m: Number of flip operations to perform (each one mandatory).
        a: The initial array of 0/1 values.
        mod: Modulus applied to the result (default 998244353).

    Returns:
        The number of ordered sequences of ``m`` segments, modulo ``mod``.
    """
    slots = n + 1  # boundaries 0..n between the zero-padded cells

    # Count active boundaries. ``prev`` starts at the left padding 0, and the
    # final ``+= prev`` accounts for the boundary with the right padding 0.
    active = 0
    prev = 0
    for value in a:
        if value != prev:
            active += 1
            prev = value
    active += prev

    dp = [0] * (slots + 1)
    dp[active] = 1
    for _ in range(m):
        ndp = [0] * (slots + 1)
        for k, ways in enumerate(dp):
            if not ways:
                # Unreachable (or zero mod ``mod``) states contribute nothing;
                # only one parity of k is ever reachable.
                continue
            free = slots - k
            if k >= 2:
                # Pick two active boundaries: both turn off.
                ndp[k - 2] += ways * (k * (k - 1) // 2)
            # Pick one active and one inactive boundary: count unchanged.
            ndp[k] += ways * (k * free)
            if free >= 2:
                # Pick two inactive boundaries: both turn on.
                ndp[k + 2] += ways * (free * (free - 1) // 2)
        dp = [v % mod for v in ndp]
    return dp[0]


def main():
    """Read ``N M`` and the array from stdin, then print the answer."""
    n, m = map(int, input().split())
    a = list(map(int, input().split()))
    print(count_ways(n, m, a))


if __name__ == "__main__":
    main()