"""Sum of component maxima over all vertex subsets of a tree.

Input (stdin): ``n``, then ``n`` vertex values ``A_1 .. A_n``, then ``n - 1``
edges ``u v`` (1-indexed).  Output: the sum, over every non-empty vertex subset
``S`` and every vertex ``pos`` in ``S``, of the largest ``A`` value in the
connected component of ``pos`` inside the subgraph induced by ``S``, taken
modulo 998244353.

Vertices are processed in decreasing ``(A, index)`` order.  When vertex ``i``
is processed, every earlier ("used") vertex outranks it, so ``i`` is the
maximum of exactly those components that contain ``i`` and no used vertex.
A rerooting DP counts, for every ``pos`` reachable from ``i`` through unused
vertices, how many subsets ``S`` produce such a component around ``pos``; that
count is weighted by ``A[i]``.  Each root costs O(n), so the total is O(n^2).

All traversals are iterative, so the script needs neither a huge recursion
limit nor PyPy's ``pypyjit`` module (which does not exist on CPython).
"""
import sys

MOD = 998244353


def _rooted_order(adj, root):
    """Return ``(parent, order)`` for the tree ``adj`` rooted at ``root``.

    ``parent[root]`` is -1, and ``order`` lists every reachable vertex after
    its parent (a top-down order; reversed, it is bottom-up).
    """
    n = len(adj)
    parent = [-1] * n
    seen = [False] * n
    seen[root] = True
    order = [root]
    stack = [root]
    while stack:
        v = stack.pop()
        for w in adj[v]:
            if not seen[w]:
                seen[w] = True
                parent[w] = v
                order.append(w)
                stack.append(w)
    return parent, order


def _subtree_sizes(parent, order):
    """Return subtree sizes for the rooting described by ``parent``/``order``."""
    size = [1] * len(parent)
    for v in reversed(order):
        if parent[v] >= 0:
            size[parent[v]] += size[v]
    return size


def _branch_factor(w, used, dp, size, pow2, mod):
    """Number of choices for the subtree of child ``w`` of a component vertex.

    ``pow2[size[w] - 1]`` counts subsets of ``w``'s subtree that leave ``w``
    out (the rest of that subtree is unconstrained).  A used ``w`` must be left
    out; an unused ``w`` may also be included, giving ``dp[w]`` more choices.
    """
    free = pow2[size[w] - 1]
    return free if used[w] else (dp[w] + free) % mod


def _count_for_root(adj, root, used, pow2, mod):
    """Count pairs ``(pos, S)`` whose component of ``pos`` in ``S`` has max ``root``.

    ``used[v]`` is True for vertices that outrank ``root``; such vertices must
    not be in the component.  Returns the count modulo ``mod``.
    """
    n = len(adj)
    parent, order = _rooted_order(adj, root)
    size = _subtree_sizes(parent, order)

    # Unused vertices reachable from root without crossing a used vertex,
    # in top-down order (order[0] is root itself).
    in_comp = [False] * n
    in_comp[root] = True
    comp = [root]
    for v in order[1:]:
        if not used[v] and in_comp[parent[v]]:
            in_comp[v] = True
            comp.append(v)

    # dp[v]: choices inside v's subtree given v is in S, with no used vertex
    # joining v's component.  Children are finished first (bottom-up).
    dp = [0] * n
    for v in reversed(comp):
        prod = 1
        for w in adj[v]:
            if w != parent[v]:
                prod = prod * _branch_factor(w, used, dp, size, pow2, mod) % mod
        dp[v] = prod

    # Rerooting pass.  up[v] is dp of parent[v] when the tree is rerooted at v.
    # The parent direction contributes only up[v] (no "leave it out" option),
    # because the path from pos back to root must stay inside S.
    up = [0] * n
    total = 0
    for v in comp:
        nbrs = adj[v]
        factors = [
            up[v] if w == parent[v]
            else _branch_factor(w, used, dp, size, pow2, mod)
            for w in nbrs
        ]
        k = len(factors)
        prefix = [1] * (k + 1)
        for j in range(k):
            prefix[j + 1] = prefix[j] * factors[j] % mod
        suffix = [1] * (k + 1)
        for j in range(k - 1, -1, -1):
            suffix[j] = suffix[j + 1] * factors[j] % mod
        total += prefix[k]
        # Product of all factors except child w's, without modular division.
        for j, w in enumerate(nbrs):
            if w != parent[v] and not used[w]:
                up[w] = prefix[j] * suffix[j + 1] % mod
    return total % mod


def solve(n, A, edge_list, mod=MOD):
    """Return the subset/component-maximum sum described in the module docstring.

    Args:
        n: number of vertices.
        A: vertex values, ``A[v]`` for ``v`` in ``range(n)``.
        edge_list: the ``n - 1`` tree edges as 0-indexed ``(u, v)`` pairs.
        mod: modulus of the result (defaults to 998244353).

    Returns:
        The sum modulo ``mod``.
    """
    adj = [[] for _ in range(n)]
    for u, v in edge_list:
        adj[u].append(v)
        adj[v].append(u)
    pow2 = [1] * (n + 1)
    for k in range(1, n + 1):
        pow2[k] = pow2[k - 1] * 2 % mod
    used = [False] * n
    ans = 0
    # Descending (value, index): ties are broken by the larger index first.
    for a, i in sorted(((a, i) for i, a in enumerate(A)), reverse=True):
        ans = (ans + _count_for_root(adj, i, used, pow2, mod) * a) % mod
        used[i] = True
    return ans


def main():
    """Read the tree from stdin (1-indexed edges) and print the answer."""
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    A = [int(x) for x in data[1:n + 1]]
    ends = [int(x) - 1 for x in data[n + 1:n + 1 + 2 * (n - 1)]]
    edge_list = list(zip(ends[0::2], ends[1::2]))
    print(solve(n, A, edge_list))


if __name__ == "__main__":
    main()