import sys
import time
import random
from collections import deque, Counter, defaultdict


def debug(*x):
    """Print the given values to stderr, prefixed with 'debug:'."""
    print('debug:', *x, file=sys.stderr)


def input():
    """Read one line from stdin with trailing whitespace removed."""
    # NOTE: intentionally shadows the builtin; the helpers below rely on it.
    return sys.stdin.readline().rstrip()


def ii():
    """Read one line and parse it as a single integer."""
    return int(input())


def mi():
    """Read one line and return an iterator over its whitespace-separated ints."""
    return map(int, input().split())


def li():
    """Read one line and return its whitespace-separated ints as a list."""
    return list(mi())


inf = 2 ** 61 - 1
mod = 998244353


def TreeDepth(s, graph):
    """Return BFS distances (in edges) from vertex ``s`` to every vertex.

    ``graph`` is an adjacency list indexed by vertex.  Vertices that are not
    reachable from ``s`` keep the sentinel value ``2 ** 61 - 1``.
    """
    depth = [inf] * len(graph)
    depth[s] = 0
    queue = deque([s])
    while queue:
        now = queue.popleft()
        for to in graph[now]:
            if depth[to] == inf:
                depth[to] = depth[now] + 1
                queue.append(to)
    return depth


def TreeOrder(s, graph):
    """Return all vertices sorted by BFS distance from ``s`` (stable by index)."""
    dist = TreeDepth(s, graph)
    return sorted(range(len(graph)), key=dist.__getitem__)


def subTree(s, graph):
    """Return subtree sizes for the tree ``graph`` rooted at ``s``.

    Vertices are processed deepest-first.  A vertex's parent has not been
    processed yet and still holds 0, so summing over *all* neighbours only
    picks up the children's sizes.
    """
    sub = [0] * len(graph)
    for v in reversed(TreeOrder(s, graph)):
        sub[v] = 1 + sum(sub[to] for to in graph[v])
    return sub


def Treeheight(s, graph):
    """Return, per vertex, the vertex count of the longest downward path.

    The tree is rooted at ``s``; a leaf has height 1.  As in ``subTree``, the
    not-yet-processed parent contributes 0 and therefore never wins the max.
    """
    height = [0] * len(graph)
    for v in reversed(TreeOrder(s, graph)):
        height[v] = max((height[to] for to in graph[v]), default=0) + 1
    return height


def EulerTour(s, graph):
    """Return an Euler tour of the tree rooted at ``s``.

    Each vertex appears once when it is entered and once when it is left.
    The traversal is iterative: a non-negative stack entry ``i`` means
    "enter i", and the bitwise complement ``~i`` means "leave i".
    """
    done = [0] * len(graph)
    stack = [~s, s]  # push the root: leave marker first, enter marker on top
    tour = []
    while stack:
        i = stack.pop()
        if i >= 0:
            # Pre-order (entering) step.
            done[i] = 1
            tour.append(i)
            # Reversed so neighbours are visited in adjacency-list order.
            for nxt in reversed(graph[i]):
                if done[nxt]:
                    continue
                stack.append(~nxt)  # schedule the post-order step
                stack.append(nxt)   # schedule the pre-order step
        else:
            # Post-order (leaving) step.
            tour.append(~i)
    return tour


def solve(a, graph):
    """Return the path-product sum of a vertex-weighted tree, modulo ``mod``.

    ``a[v]`` is the weight of vertex ``v`` and ``graph`` is the 0-indexed
    adjacency list of a tree, which is rooted at vertex 0 here.  The
    recurrence adds up, over every simple path with at least two vertices
    (each unordered path once), the product of the weights on that path.

    ``dp[v]`` is the sum of products over all downward paths that start at
    ``v`` (including the single-vertex path): ``a[v] * (1 + sum dp[child])``.
    """
    n = len(graph)
    if n == 0:
        return 0

    depth = TreeDepth(0, graph)
    # Deepest vertices first, so children are finished before their parent.
    order = sorted(range(n), key=depth.__getitem__)
    inv2 = pow(2, -1, mod)

    dp = [0] * n
    ans = 0
    for v in reversed(order):
        av = a[v]
        child_sum = 0  # sum of dp over children
        child_sq = 0   # sum of dp^2 over children
        for to in graph[v]:
            if depth[to] > depth[v]:  # `to` is a child of v
                child_sum += dp[to]
                child_sq += dp[to] * dp[to]
        child_sum %= mod

        dp[v] = av * (1 + child_sum) % mod
        # Paths whose top vertex is v and that descend into two distinct
        # children: a[v] * sum_{c1 < c2} dp[c1] * dp[c2]
        #         = a[v] * (child_sum^2 - child_sq) / 2.
        through = av * (child_sum * child_sum - child_sq) % mod * inv2 % mod
        # dp[v] - a[v] covers paths with v as an endpoint and >= 2 vertices.
        ans = (ans + dp[v] - av + through) % mod
    return ans


def main():
    """Read the tree from stdin (1-indexed edges) and print the answer."""
    n = ii()
    a = li()
    graph = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = mi()
        u -= 1
        v -= 1
        graph[u].append(v)
        graph[v].append(u)
    print(solve(a, graph))


if __name__ == "__main__":
    main()