import sys

# Line reader used by the helpers below (shadows the builtin on purpose, as in
# the original script). rstrip("\n") instead of [:-1]: slicing off the last
# character corrupts a final line that has no trailing newline.
input = lambda: sys.stdin.readline().rstrip("\n")
ni = lambda: int(input())
na = lambda: list(map(int, input().split()))
sys.setrecursionlimit(10**7)
yes = lambda: print("yes"); Yes = lambda: print("Yes"); YES = lambda: print("YES")
no = lambda: print("no"); No = lambda: print("No"); NO = lambda: print("NO")
#######################################################################

MOD = 998244353


def solve(n, lengths, mod=MOD):
    """Evaluate the modular DP and return dp[n] / (n + 1) modulo `mod`.

    dp[0] = 1 and, for 1 <= k <= n,

        dp[k] = k^-1 * sum over j in lengths with j <= k of
                dp[k - j] * ((n + 1) * j - (k - j))

    with all arithmetic taken modulo `mod`. Modular inverses are computed
    with Fermat's little theorem (pow(x, mod - 2, mod)), so `mod` must be
    prime and must not divide any of 1..n or n + 1.

    Args:
        n: index of the DP entry to evaluate; also appears in the (n + 1)
            weight. NOTE(review): presumably the problem size -- confirm.
        lengths: step sizes j used in the recurrence; presumably positive
            integers -- confirm against the problem constraints.
        mod: prime modulus (default 998244353).

    Returns:
        An int in [0, mod).
    """
    dp = [0] * (n + 1)
    dp[0] = 1
    for k in range(1, n + 1):
        for j in lengths:
            if j <= k:
                prev = dp[k - j]
                # Both terms of the original two-step update, folded into one
                # reduction; the result modulo `mod` is identical.
                dp[k] = (dp[k] + prev * j % mod * (n + 1) - prev * (k - j)) % mod
        # Divide the accumulated sum by k via its modular inverse.
        dp[k] = dp[k] * pow(k, mod - 2, mod) % mod
    return dp[n] * pow(n + 1, mod - 2, mod) % mod


def main():
    """Read `n m` and the step-size list from stdin and print the answer."""
    n, _m = na()  # second value is read but unused (presumably len(l))
    l = na()
    print(solve(n, l))


if __name__ == "__main__":
    main()