import sys

MOD = 998244353


def polynomial_coefficients(A, mod=MOD):
    """Return the coefficients of prod_i (x + A[i] - 1), reduced modulo *mod*.

    Element ``k`` of the returned list (length ``len(A) + 1``) is the
    coefficient of ``x**k``.  Equivalently, it is the sum over all ways of
    picking exactly ``k`` indices of the product of ``A[i] - 1`` over the
    indices *not* picked.  All values are in ``[0, mod)``.
    """
    coef = [0] * (len(A) + 1)
    coef[0] = 1
    for i, a in enumerate(A):
        # Normalize up front so a == 0 (weight -1) stays non-negative.
        weight = (a - 1) % mod
        # Walk j downwards so coef[j - 1] still holds the previous round's
        # value when it is read; only indices 0..i+1 can be non-zero here.
        for j in range(i + 1, 0, -1):
            coef[j] = (coef[j] * weight + coef[j - 1]) % mod
        coef[0] = coef[0] * weight % mod
    return coef


def solve(A, queries, mod=MOD):
    """Answer each query ``t`` with the coefficient of ``x**t`` in
    prod_i (x + A[i] - 1) modulo *mod*.

    Queries outside ``[0, len(A)]`` get 0, which is the true coefficient there.
    Without this check, a negative ``t`` would wrap around to the last entry.
    """
    coef = polynomial_coefficients(A, mod)
    size = len(coef)
    return [coef[t] if 0 <= t < size else 0 for t in queries]


def main():
    """Read ``n q`` / A / queries from stdin and print one answer per line."""
    readline = sys.stdin.buffer.readline
    n, _q = map(int, readline().split())
    A = list(map(int, readline().split()))
    queries = map(int, readline().split())
    answers = solve(A[:n], queries)
    # One buffered write instead of a print per query; emits nothing when
    # there are no queries, exactly like the per-line print loop did.
    sys.stdout.write("".join(f"{x}\n" for x in answers))


if __name__ == "__main__":
    main()