MOD = 998244353 ntt_perm = [[0]] def ntt(A,g): logN = 0 while (1 << logN) < len(A) : logN += 1 N = 1 << logN A += [0] * (N-len(A)) while len(ntt_perm) < logN + 1: ntt_perm.append([2*x for x in ntt_perm[-1]] + [2*x+1 for x in ntt_perm[-1]]) X = ntt_perm[logN] for i in range(N): if i < X[i]: A[i],A[X[i]] = A[X[i]],A[i] i = 1 while i < N: q = pow(g,(MOD-1)//i//2,MOD) qj = 1 for j in range(0,N,i*2): qj = 1 for k in range(j,j+i): A[k], A[k+i] = (A[k] + A[k+i] * qj) % MOD , (A[k] - A[k+i] * qj) % MOD qj = qj * q % MOD i *= 2 N = int(input()) Q = list(map(int,input().split())) E = [[] for i in range(N)] for i in range(N-1) : u,v = map(int,input().split()) E[u-1].append(v-1) E[v-1].append(u-1) cdep = [0] * N cp = [-1] * N cbfs = [0] * N def centroid_decomposition() : P = [-1] * N I = [0] for p in I : for e in E[p] : if P[p] == e : continue P[e] = p I.append(e) Z = [1] * N for p in I[-1:0:-1] : Z[P[p]] += Z[p] cdI = [(0,-1)] for s,par in cdI : while True : nx = -1 for e in E[s] : if Z[e] * 2 > Z[s] : nx = e if nx == -1 : break Z[s],Z[nx] = Z[s]-Z[nx],Z[s] s = nx cbfs.append(s) Z[s] = 0 if par != -1 : cdep[s],cp[s] = cdep[par]+1,par for e in E[s] : if Z[e] != 0 : cdI.append((e,s)) centroid_decomposition() NTTg = 3 invNTTg = pow(NTTg, MOD-2, MOD) max_ntt_size_log = 0 while (1 << max_ntt_size_log) < N + 6 : max_ntt_size_log += 1 max_ntt_size_log += 1 max_ntt_size = 1 << max_ntt_size_log k0 = 1 for i in range(1, N+1) : k0 = k0 * i % MOD k0 = k0 * k0 % MOD inv_mod = [1] * (max_ntt_size+1) for i in range(2,max_ntt_size+1) : inv_mod[i] = MOD - MOD // i * inv_mod[MOD%i] % MOD C = [0] * (max_ntt_size+1) for i in range(max_ntt_size) : C[i] = k0 * inv_mod[i+1] % MOD * inv_mod[i+1] % MOD nttC = [[] for i in range(max_ntt_size_log+1)] for d in range(max_ntt_size_log+1) : inv_ntt_size = pow(1<