import sys

MOD = 998244353


def coefficients(values, mod=MOD):
    """Return the coefficients of prod_i ((a_i - 1) + x), reduced modulo ``mod``.

    This is the same DP as a row-by-row table where, for each ``a`` in
    ``values``, state ``j`` stays at ``j`` with weight ``a - 1`` or advances to
    ``j + 1`` with weight 1; entry ``t`` of the result is the table value at
    ``j == t`` after all of ``values`` have been processed.

    Args:
        values: the sequence A (ints; ``a - 1`` may be negative, it is reduced
            modulo ``mod``).
        mod: modulus applied to every intermediate value.

    Returns:
        A list of length ``len(values) + 1``; index ``t`` is the coefficient
        of ``x**t``. An empty ``values`` yields ``[1 % mod]`` (empty product).

    Runs in O(len(values)**2) time and O(len(values)) extra space.
    """
    coef = [0] * (len(values) + 1)
    coef[0] = 1 % mod
    for i, a in enumerate(values):
        other = (a - 1) % mod
        # Multiply the degree-i polynomial by (other + x) in place. Walking j
        # downward keeps coef[j - 1] at its previous-row value when it is read.
        for j in range(i + 1, 0, -1):
            coef[j] = (coef[j] * other + coef[j - 1]) % mod
        coef[0] = coef[0] * other % mod
    return coef


def solve(data):
    """Turn the raw input into the output text (one answer per line, no trailing newline).

    Expected input: ``n q``, then ``n`` integers A, then the query values t.
    All whitespace-separated tokens after A are treated as queries, so the
    queries may be on one line or spread over several lines. A query outside
    ``0..n`` has coefficient 0 and prints ``0`` (the original indexed the list
    directly, raising for t > n and wrapping around for negative t).
    """
    tokens = data.split()
    n = int(tokens[0])
    values = [int(tok) for tok in tokens[2:2 + n]]
    queries = [int(tok) for tok in tokens[2 + n:]]
    coef = coefficients(values)
    return "\n".join(str(coef[t]) if 0 <= t <= n else "0" for t in queries)


def main():
    """Read all of stdin, answer every query, and write the answers to stdout."""
    out = solve(sys.stdin.buffer.read())
    if out:
        sys.stdout.write(out + "\n")


if __name__ == "__main__":
    main()