"""Inclusion-exclusion count over the multiplicities of a sequence, mod 998244353.

Input (stdin, whitespace-separated): ``M N A_1 ... A_N``.
Output: the value computed by :func:`solve` for ``A_1..A_N``.
"""

import sys
from collections import Counter

MOD = 998244353


def solve(values, n=None):
    """Return the inclusion-exclusion total for the sequence *values*, mod MOD.

    Let ``c_1..c_r`` be the multiplicities of the distinct entries of
    *values*. Let ``f(t) = prod_i t! / (t - c_i)!``, which is the number of ways
    to give the ``c_i`` copies of each entry pairwise-distinct labels chosen
    from ``t`` labels. The result is::

        sum_{K=K_min}^{n} sum_{d=0}^{K-K_min} (-1)^d * C(K, d) * f(K - d)

    where ``K_min = max(c_i)``. The inner sum is the inclusion-exclusion
    count of such labelings that use every one of the ``K`` labels.

    Args:
        values: the sequence whose multiplicities are counted.
        n: upper bound for ``K``. It defaults to ``len(values)``. ``main``
            passes the declared ``N`` from the input.

    Returns:
        The total modulo ``MOD``, or 0 when *values* is empty.
    """
    counts = list(Counter(values).values())
    if not counts:
        return 0
    if n is None:
        n = len(values)
    k_min = max(counts)

    # Factorials and inverse factorials up to n + 1. MOD is prime, so the
    # inverse comes from Fermat's little theorem.
    size = n + 2
    fact = [1] * size
    for i in range(1, size):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * size
    inv_fact[-1] = pow(fact[-1], MOD - 2, MOD)
    for i in range(size - 1, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD

    # f[t] is needed only for t >= k_min. Below that, the largest group cannot
    # receive distinct labels, so f(t) = 0. Since t >= k_min >= c for every
    # c, each falling factorial t!/(t-c)! below is well-defined.
    f = [0] * (n + 1)
    for t in range(k_min, n + 1):
        prod = 1
        for c in counts:
            prod = prod * fact[t] % MOD * inv_fact[t - c] % MOD
        f[t] = prod

    total = 0
    for k in range(k_min, n + 1):
        acc = 0
        # d <= k - k_min keeps k - d >= k_min. Terms with smaller t would be
        # zero anyway.
        for d in range(k - k_min + 1):
            # C(k, d) with 0 <= d <= k <= n, so every index is in range.
            term = f[k - d] * fact[k] % MOD * inv_fact[d] % MOD * inv_fact[k - d] % MOD
            # (-1)^d: alternate the sign instead of calling pow() per term.
            acc += -term if d & 1 else term
        # Python's % always yields a non-negative result, even if acc < 0.
        total = (total + acc) % MOD
    return total


def main():
    """Read ``M N A_1..A_N`` from stdin and print ``solve(A, N)``."""
    data = sys.stdin.read().split()
    # NOTE(review): M is parsed (and must be an integer) but never used by the
    # computation. Confirm against the problem statement that K should not be
    # bounded by M.
    _m = int(data[0])
    n = int(data[1])
    values = list(map(int, data[2:2 + n]))
    print(solve(values, n))


if __name__ == '__main__':
    main()