import sys sys.setrecursionlimit(10**8) import pypyjit pypyjit.set_param('max_unroll_recursion=-1') MOD = 998244353 N = int(input()) A = [0] + list(map(int, input().split())) G = [list() for _ in range(N + 1)] for i in range(N - 1): u, v = map(int, input().split()) G[u].append(v) G[v].append(u) # DP[i] # 部分木iにおいて、f(i, 部分木i内の全頂点) の合計 DP = [0] * (N + 1) def dfs(pos, pre): for nex in G[pos]: if nex == pre: continue dfs(nex, pos) # 親に計上 DP[pos] += DP[nex] # 子どもの計上がすべて完了したら、+1 してから A[i] を掛ける DP[pos] = (DP[pos] + 1) * A[pos] % MOD return dfs(1, 0) # 全方位木DPを解く DP2 = [0] * (N + 1) def dfs2(pos, pre): # pos が根になったときの、DP が答え DP2[pos] = DP[pos] # 根の移動 for nex in G[pos]: if nex == pre: continue # 設定のバックアップ memo1, memo2 = DP[pos], DP[nex] # DP[pos] からは DP[nex] 分の寄与を削る DP[pos] -= A[pos] * DP[nex] DP[pos] %= MOD # DP[nex] には、DP[pos](更新後)の寄与を追加する DP[nex] += A[nex] * DP[pos] DP[nex] %= MOD dfs2(nex, pos) # リストアする DP[pos], DP[nex] = memo1, memo2 return dfs2(1, 0) ans = sum(DP2) - sum(A) ans *= pow(2, -1, MOD) ans %= MOD print(ans)