"""Sum of vertex-weight products over tree paths, modulo 998244353.

Input (whitespace-separated on stdin): ``n``, then ``n`` vertex weights,
then ``n - 1`` edges ``u v`` using 1-based vertex ids. The tree is rooted
at vertex 1 and the program prints a single integer.
"""
import sys

MOD = 998244353


def solve(n, weights, edge_list, mod=MOD):
    """Return the sum of weight products over all paths with >= 2 vertices.

    For every simple path in the tree that contains at least two vertices,
    multiply the weights of the vertices on it; the result is the sum of
    those products modulo *mod*. Single-vertex paths are not counted.
    NOTE(review): that exclusion matches the original recurrence; confirm it
    against the problem statement.

    Args:
        n: number of vertices; vertices are numbered 1..n.
        weights: sequence of n integers; ``weights[i - 1]`` is vertex i's weight.
        edge_list: iterable of n - 1 ``(u, v)`` pairs (1-based) forming a tree.
        mod: modulus for the result (defaults to 998244353).

    Returns:
        The sum modulo *mod*; 0 when the tree has fewer than two vertices.
    """
    if n <= 0:
        return 0

    adj = [[] for _ in range(n + 1)]
    for u, v in edge_list:
        adj[u].append(v)
        adj[v].append(u)

    # Iterative DFS from root 1 (avoids Python's recursion limit on deep
    # trees). Walking the visit order backwards finalizes every child
    # before its parent, i.e. it is a valid post-order for the DP below.
    parent = [0] * (n + 1)
    order = []
    stack = [1]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)

    # dp[u]:  sum of products over downward paths starting at u, including
    #         the single-vertex path {u}.
    # ans[u]: the answer restricted to the subtree rooted at u.
    dp = [0] * (n + 1)
    ans = [0] * (n + 1)
    for u in reversed(order):
        w = weights[u - 1]
        sub_total = 0  # sum of ans over children
        down = 0       # running sum of dp over children seen so far
        cross = 0      # sum of dp[a] * dp[b] over unordered child pairs
        for v in adj[u]:
            if v == parent[u]:
                continue
            sub_total += ans[v]
            # Pair this child with every earlier child before adding it in.
            cross = (cross + dp[v] * down) % mod
            down = (down + dp[v]) % mod
        # New paths whose topmost vertex is u: two child branches joined
        # through u (w * cross), or one child branch extended up to u
        # (w * down).
        ans[u] = (sub_total + w * (cross + down)) % mod
        dp[u] = w * (1 + down) % mod
    return ans[1]


def main():
    """Parse the tree from stdin and print ``solve``'s result."""
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    weights = [int(tok) for tok in data[1:n + 1]]
    # Consume exactly n - 1 edges; any trailing tokens are ignored.
    m = max(n - 1, 0)
    flat = [int(tok) for tok in data[n + 1:n + 1 + 2 * m]]
    edge_list = list(zip(flat[0::2], flat[1::2]))
    print(solve(n, weights, edge_list))


if __name__ == "__main__":
    main()