"""Count size-k subsets whose sum s satisfies ``s % 998244353 <= s % 998``.

Input (stdin):
    line 1: n k
    line 2: a_1 ... a_n

Output: the number of qualifying size-k subsets (chosen by position),
reduced modulo 998.
"""
from itertools import combinations

BIG_MOD = 998244353
SMALL_MOD = 998


def count_valid_subsets(a, k):
    """Return how many size-*k* subsets of *a* satisfy the modulus test.

    A subset with sum ``s`` qualifies when ``s % BIG_MOD <= s % SMALL_MOD``.
    Subsets are chosen by position, so equal values at different indices
    count as distinct subsets. ``k == 0`` counts the empty subset (sum 0)
    once; ``k > len(a)`` yields 0.

    Returns:
        The qualifying-subset count modulo ``SMALL_MOD``.
    """
    count = 0
    # combinations() visits only the C(n, k) subsets of the requested size,
    # instead of enumerating all 2**n bitmasks and discarding most of them.
    for combo in combinations(a, k):
        total = sum(combo)
        # Sums below SMALL_MOD always pass (both sides equal total); sums in
        # [SMALL_MOD, BIG_MOD) always fail (left side is total itself).
        if total % BIG_MOD <= total % SMALL_MOD:
            count += 1
    return count % SMALL_MOD


def main():
    """Read n, k and the array from stdin and print the answer."""
    n, k = map(int, input().split())
    a = list(map(int, input().split()))
    # Only the first n values participate, matching the declared length.
    print(count_valid_subsets(a[:n], k))


if __name__ == "__main__":
    main()