import sys input = sys.stdin.readline mod=998244353 N,M=map(int,input().split()) A=list(map(int,input().split())) A=[0]+A+[0] now=0 for i in range(len(A)-1): if A[i]!=A[i+1]: now+=1 DP=[0]*(N+2) DP[now]=1 #print(DP) for i in range(M): NDP=[0]*(N+2) # same for j in range(N+2): NDP[j]+=DP[j]*j*(N+1-j)%mod # 二個減らす for j in range(N+2): if j-2>=0: NDP[j-2]+=DP[j]*(j*(j-1)//2)%mod # 二個増やす for j in range(N+2): if j+2