import sys

readline = sys.stdin.readline

MOD = 998244353


def ns():
    """Read one line from stdin with trailing whitespace stripped."""
    return readline().rstrip()


def ni():
    """Read one line from stdin as a single int."""
    return int(readline().rstrip())


def nm():
    """Read one line from stdin as a lazy iterator of ints."""
    return map(int, readline().split())


def nl():
    """Read one line from stdin as a list of ints."""
    return list(map(int, readline().split()))


def product_coefficients(a, mod=MOD):
    """Return the coefficients of prod_i (x + (a_i - 1)) modulo ``mod``.

    Element ``k`` of the returned list is the coefficient of ``x**k``; the
    list has length ``len(a) + 1``.  An empty ``a`` yields ``[1]``.
    """
    coef = [1]
    for ai in a:
        c = (ai - 1) % mod
        coef.append(0)
        # Multiply in place by (x + c): new[j] = old[j-1] + c * old[j].
        # Walk downward so coef[j - 1] still holds the old value when read.
        for j in range(len(coef) - 1, 0, -1):
            coef[j] = (coef[j - 1] + c * coef[j]) % mod
        coef[0] = c * coef[0] % mod
    return coef


def answer_queries(n, a, queries, mod=MOD):
    """Return the coefficient of ``x**k`` for each ``k`` in ``queries``.

    Only the first ``n`` values of ``a`` are used as factors.  Indices
    outside the polynomial's degree range (negative or greater than the
    degree) have coefficient 0.
    """
    coef = product_coefficients(a[:n], mod)
    size = len(coef)
    return [coef[k] if 0 <= k < size else 0 for k in queries]


def solve():
    """Read ``n q``, the list ``a`` and the query list ``b``, print answers.

    Prints one answer per line, in query order.
    """
    n, _q = nm()
    a = nl()
    b = nl()
    answers = answer_queries(n, a, b)
    if answers:
        # One buffered write instead of one print() call per query.
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    solve()