MOD = 998244353


def _rooted_postorder(n, adj, root=1):
    """Root the tree at ``root`` and return ``(order, children)``.

    ``children[u]`` lists the neighbours of ``u`` that lie below it once the
    tree is rooted.  ``order`` contains every vertex reachable from ``root``
    with each vertex appearing after all of its descendants, so a forward
    scan visits children before their parents.  Iterative, so deep
    (path-like) trees cannot hit the recursion limit.
    """
    children = [[] for _ in range(n + 1)]
    preorder = []
    stack = [(root, 0)]
    while stack:
        u, p = stack.pop()
        preorder.append(u)
        for v in adj[u]:
            if v != p:
                children[u].append(v)
                stack.append((v, u))
    # Every child is appended after its parent in this DFS pre-order, so the
    # reversed list places all descendants before their ancestor.
    preorder.reverse()
    return preorder, children


def solve(n, edge_list, A):
    """Return the tree answer modulo ``MOD``.

    For every bit ``b``, count the ways to delete a subset of tree edges so
    that, in every resulting component, the XOR of the component's values
    has bit ``b`` set.  The result is ``sum(count_b * 2**b) % MOD``.
    Equivalently, it is the sum, over all edge-deletion partitions, of the
    bitwise AND of the component XORs.

    Args:
        n: number of vertices, labelled ``1..n``.
        edge_list: iterable of ``(u, v)`` pairs forming a tree.
        A: vertex values; ``A[u - 1]`` belongs to vertex ``u``.  Assumed
            non-negative.

    Returns:
        The answer as an int in ``[0, MOD)``.
    """
    if n <= 0:
        return 0
    adj = [[] for _ in range(n + 1)]
    for u, v in edge_list:
        adj[u].append(v)
        adj[v].append(u)

    root = 1
    order, children = _rooted_postorder(n, adj, root)

    # Cover at least the 30 bits the original scanned, plus any higher bit
    # actually present in the input (the old fixed range dropped those).
    num_bits = max(30, max((x.bit_length() for x in A), default=0))
    present = 0
    for x in A:
        present |= x

    # dp0[u] / dp1[u]: ways for u's subtree where every already-closed
    # component has bit b set and the component still containing u has
    # XOR-parity 0 / 1 at bit b.  Arrays are fully overwritten each pass.
    dp0 = [0] * (n + 1)
    dp1 = [0] * (n + 1)
    total = 0
    for b in range(num_bits):
        if not (present >> b) & 1:
            continue  # no vertex has this bit: every component XOR is 0, ways == 0
        for u in order:
            if (A[u - 1] >> b) & 1:
                c0, c1 = 0, 1
            else:
                c0, c1 = 1, 0
            for v in children[u]:
                v0 = dp0[v]
                v1 = dp1[v]
                # Cut edge (u, v): v's component is closed and must be odd -> * v1.
                # Keep edge (u, v): parities merge by XOR.
                n0 = (c0 * v1 + c0 * v0 + c1 * v1) % MOD
                n1 = (c1 * v1 + c0 * v1 + c1 * v0) % MOD
                c0, c1 = n0, n1
            dp0[u] = c0
            dp1[u] = c1
        # The root's component is closed at the end and must be odd as well.
        total = (total + dp1[root] * pow(2, b, MOD)) % MOD
    return total


def main():
    """Read the tree from stdin and print the answer.

    Input: ``n``, then ``n - 1`` edges ``u v``, then ``n`` vertex values,
    all whitespace-separated.
    """
    import sys

    data = sys.stdin.read().split()
    n = int(data[0])
    pos = 1
    edge_list = []
    for _ in range(n - 1):
        edge_list.append((int(data[pos]), int(data[pos + 1])))
        pos += 2
    A = [int(x) for x in data[pos:pos + n]]
    print(solve(n, edge_list, A))


if __name__ == "__main__":
    main()