MOD = 998244353


def _binomials(n, m, top):
    """Return [C(m+n-1+j, n-1+j) % MOD for j in 0..top], built incrementally."""
    value = 1
    # C(m+n-1, n-1) as a running product; each division is done with a
    # Fermat inverse (MOD is prime).
    for i in range(n - 1):
        value = value * ((m + n - 1 - i) % MOD) % MOD
        value = value * pow(i + 1, MOD - 2, MOD) % MOD
    table = [value]
    # C(a+1, b+1) = C(a, b) * (a+1) / (b+1) advances both arguments by one.
    for j in range(1, top + 1):
        value = value * ((m + n - 1 + j) % MOD) % MOD
        value = value * pow(n - 1 + j, MOD - 2, MOD) % MOD
        table.append(value)
    return table


def solve(N, M, k):
    """Evaluate the weighted binomial sum for exponents ``k``, modulo MOD.

    For every exponent ``e >= 1`` the power ``x**e`` is expanded as
    ``sum_m c_e[m] * C(x+m-1, m-1)`` where ``c_0`` is 1 at ``m = 1`` and
    ``c_e[m] = (m-1)*c_{e-1}[m-1] - m*c_{e-1}[m]``.  The result is

        sum_e cnt[e] * sum_{m=1..e+1} c_e[m] * C(M+N-1+m, N-1+m)   (mod MOD)

    with ``cnt[e]`` the number of times ``e`` occurs among the first ``N``
    entries of ``k``.  Entries equal to 0 contribute nothing.

    NOTE(review): derived from the formula (and checked by hand on small
    inputs), this equals the sum of ``a_i ** k_i`` over all ``i`` and over
    all N-tuples of non-negative integers ``a`` with ``sum(a) <= M``;
    confirm against the original problem statement.

    Args:
        N: number of exponents to use (the first ``N`` entries of ``k``).
        M: the bound that enters the binomial coefficients.
        k: list of non-negative integer exponents, at least ``N`` long.

    Returns:
        The sum above reduced into ``range(MOD)``; 0 when no exponents are
        given.

    Raises:
        ValueError: if ``k`` holds fewer than ``N`` entries.
    """
    if len(k) < N:
        raise ValueError("expected %d exponents, got %d" % (N, len(k)))
    exponents = k[:N]
    if not exponents:
        return 0

    largest = max(exponents)
    cnt = [0] * (largest + 1)
    for e in exponents:
        cnt[e] += 1

    binom = _binomials(N, M, largest + 1)

    # coef[m] holds c_e[m]; slot 0 is a permanent zero so that the update
    # below also covers m == 1 without a special case.
    coef = [0] * (largest + 2)
    coef[1] = 1
    total = 0
    for e in range(1, largest + 1):
        # Descending m keeps coef[m - 1] at its previous-step value, so the
        # recurrence can be applied in place.  Only slots 1..e+1 can be
        # non-zero after this step.
        for m in range(e + 1, 0, -1):
            coef[m] = ((m - 1) * coef[m - 1] - m * coef[m]) % MOD
        if not cnt[e]:
            # The coefficients must still advance, but nothing is added.
            continue
        partial = 0
        for m in range(1, e + 2):
            partial = (partial + coef[m] * binom[m]) % MOD
        total = (total + partial * cnt[e]) % MOD
    return total


def main():
    """Read ``N M`` and the exponent list from stdin and print the answer."""
    N, M = map(int, input().split())
    k = list(map(int, input().split()))
    print(solve(N, M, k))


if __name__ == "__main__":
    main()