"""Count, modulo 998244353, the length-M operation sequences that end in state 0.

Input (stdin):  first line ``N M``, second line the sequence ``A``.
Output (stdout): one integer, the DP value for state 0 after M rounds.

NOTE(review): the transition weights match "pick an unordered pair out of
N + 1 boundary slots and toggle both" (e.g. flipping a contiguous range of a
0/1 sequence).  That reading is inferred from the arithmetic only -- confirm
against the original problem statement.
"""

MOD = 998244353


def _initial_state(a):
    """Return the DP start index derived from the sequence *a*.

    Counts the positions whose value differs from the previous one (the
    value before the first element is taken to be 0) and then adds the last
    value itself.  For a 0/1 sequence this equals the number of value changes
    in ``0, a[0], ..., a[-1], 0``.
    """
    prev = 0
    changes = 0
    for value in a:
        if value != prev:
            changes += 1
        prev = value
    return changes + prev


def count_ways(n, m, a):
    """Return the weighted number of m-step walks that finish in state 0.

    States are ``0 .. n + 1``.  From state ``i`` one round moves to

    * ``i - 2`` with weight ``C(i, 2)``,
    * ``i``     with weight ``i * (n + 1 - i)``,
    * ``i + 2`` with weight ``C(n + 1 - i, 2)``,

    and all arithmetic is done modulo ``MOD``.

    Args:
        n: size parameter; the DP table has ``n + 2`` entries.
        m: number of rounds to apply.
        a: sequence of ints (assumed to be 0/1 and of length n -- values
           outside that range can push the start index out of the table).

    Returns:
        The table entry for state 0 after ``m`` rounds, in ``[0, MOD)``.
    """
    size = n + 2

    # The weights depend only on (i, n), so compute them once rather than on
    # every round.  i * (i - 1) is always even, hence the exact integer
    # binomial is the same residue as multiplying by the inverse of 2.
    down = [i * (i - 1) // 2 % MOD for i in range(size)]
    stay = [i * (n + 1 - i) % MOD for i in range(size)]
    up = [(n + 1 - i) * (n - i) // 2 % MOD for i in range(size)]

    dp = [0] * size
    dp[_initial_state(a)] = 1

    for _ in range(m):
        ndp = [0] * size
        for i, ways in enumerate(dp):
            if not ways:
                continue  # unreachable state contributes nothing
            if i >= 2:
                ndp[i - 2] = (ndp[i - 2] + ways * down[i]) % MOD
            ndp[i] = (ndp[i] + ways * stay[i]) % MOD
            if i + 2 < size:
                ndp[i + 2] = (ndp[i + 2] + ways * up[i]) % MOD
        dp = ndp  # entries are already reduced; no element-wise copy needed

    return dp[0]


def main():
    """Read ``N M`` and ``A`` from stdin and print the answer."""
    n, m = map(int, input().split())
    a = list(map(int, input().split()))
    print(count_ways(n, m, a))


if __name__ == "__main__":
    main()