"""Sum, over every bit position, of weighted edge-cut counts on a tree.

Input (whitespace separated, read from stdin):
    N
    N - 1 pairs ``u v``   -- undirected tree edges, vertices numbered 1..N
    N integers            -- the value attached to each vertex

For each bit ``b`` the program counts the subsets of edges that can be cut
so that *every* resulting connected component has bit ``b`` set in the XOR
of its vertex values.  That count is multiplied by ``2**b`` and the products
are summed modulo ``MOD``; the total is printed.
"""
import sys

MOD = 998244353

# The original implementation always scanned exactly 30 bits.  It is kept as
# a floor so inputs that fit in 30 bits are processed exactly as before;
# wider values simply extend the scan.
_MIN_BITS = 30


def _root_tree(n, adjacency, root):
    """Root the tree at ``root`` and return ``(order, children)``.

    ``order`` lists every vertex so that a parent always precedes its
    descendants (iterative DFS, so no recursion-limit concerns);
    ``children[u]`` holds the children of ``u`` in the rooted tree.
    """
    children = [[] for _ in range(n + 1)]
    order = []
    stack = [(root, 0)]
    while stack:
        node, par = stack.pop()
        order.append(node)
        for nxt in adjacency[node]:
            if nxt != par:
                children[node].append(nxt)
                stack.append((nxt, node))
    return order, children


def _count_all_odd(bit, values, order, children):
    """Count edge-cut subsets where every component has ``bit`` set in its XOR.

    ``even[u]`` / ``odd[u]`` count the cut choices inside the subtree of ``u``
    for which every component already detached from ``u`` has the bit set and
    the (still open) component containing ``u`` has parity 0 / 1.
    The result is ``odd[root]``, reduced modulo ``MOD``.
    """
    size = len(values)
    even = [0] * size
    odd = [0] * size
    # Reversed pre-order visits every child before its parent.
    for node in reversed(order):
        if (values[node] >> bit) & 1:
            cur_even, cur_odd = 0, 1
        else:
            cur_even, cur_odd = 1, 0
        for child in children[node]:
            child_even = even[child]
            child_odd = odd[child]
            # Cut the edge: the child's component is closed and must be odd,
            # while the parity of the current component is unchanged.
            cut_even = cur_even * child_odd
            cut_odd = cur_odd * child_odd
            # Keep the edge: the two open components merge, parities XOR.
            keep_even = cur_even * child_even + cur_odd * child_odd
            keep_odd = cur_even * child_odd + cur_odd * child_even
            cur_even = (cut_even + keep_even) % MOD
            cur_odd = (cut_odd + keep_odd) % MOD
        even[node] = cur_even
        odd[node] = cur_odd
    return odd[order[0]]


def _solve(n, adjacency, values):
    """Return the bit-weighted total for a tree given as adjacency lists.

    ``values`` is 1-based (index 0 is padding).
    """
    # The tree shape does not depend on the bit, so root it only once.
    order, children = _root_tree(n, adjacency, 1)

    present = 0
    for value in values:
        present |= value
    bit_count = max(_MIN_BITS, present.bit_length())

    total = 0
    for bit in range(bit_count):
        # If no vertex carries this bit, no component can be odd and the
        # count is zero, so the whole pass can be skipped.
        if not (present >> bit) & 1:
            continue
        ways = _count_all_odd(bit, values, order, children)
        total = (total + ways * pow(2, bit, MOD)) % MOD
    return total


def main():
    """Read the tree and vertex values from stdin and print the answer."""
    # Token-based parsing tolerates blank lines and irregular line breaks.
    tokens = sys.stdin.read().split()
    n = int(tokens[0])

    adjacency = [[] for _ in range(n + 1)]
    pos = 1
    for _ in range(n - 1):
        u = int(tokens[pos])
        v = int(tokens[pos + 1])
        pos += 2
        adjacency[u].append(v)
        adjacency[v].append(u)

    # Prepend a dummy entry so vertex i's value lives at index i.
    values = [0] + [int(tok) for tok in tokens[pos:pos + n]]

    print(_solve(n, adjacency, values) % MOD)


if __name__ == '__main__':
    main()