## https://yukicoder.me/problems/no/2377
"""Sum, over every way of keeping/cutting each tree edge, of the bitwise AND
of the XORs of the resulting connected components, modulo 998244353.

The sum is computed one bit at a time: for bit k we count the edge choices in
which every component's XOR has bit k set, and add count * 2**k.
"""
import sys
from collections import deque

MOD = 998244353


def _bfs_order(N, next_nodes):
    """Return (order, parents) for a BFS of the tree rooted at vertex 0.

    Every vertex appears in ``order`` after its parent, so walking ``order``
    backwards visits all children before their parent. ``parents[0]`` is -1.
    """
    parents = [-1] * N
    order = []
    queue = deque([0])
    while queue:
        v = queue.popleft()
        order.append(v)
        for w in next_nodes[v]:
            if w != parents[v]:
                parents[w] = v
                queue.append(w)
    return order, parents


def solve(N, next_nodes, bit_a_list, order=None, parents=None):
    """Count edge choices in which every component's XOR has the bit set.

    Args:
        N: number of vertices (0-indexed; vertex 0 is used as the root).
        next_nodes: adjacency lists of the tree.
        bit_a_list: bit_a_list[v] is 0 or 1, the selected bit of A[v].
        order, parents: optional precomputed result of ``_bfs_order`` so one
            traversal can be shared between calls for different bits.

    Returns:
        The number of such edge choices, modulo MOD.
    """
    if order is None or parents is None:
        order, parents = _bfs_order(N, next_nodes)

    # dp0[v] / dp1[v]: number of ways to choose the edges inside v's subtree
    # such that every component already closed off below v has XOR bit 1,
    # and the component still containing v has XOR bit 0 / 1.
    # Choices that close a component with XOR bit 0 can never be counted at
    # the end (the AND would be 0), so they are dropped immediately instead
    # of being tracked in extra DP states.
    dp0 = [1 - b for b in bit_a_list]
    dp1 = list(bit_a_list)

    for v in reversed(order):
        p = parents[v]
        if p < 0:
            continue  # the root has nothing to merge into
        c0, c1 = dp0[v], dp1[v]
        p0, p1 = dp0[p], dp1[p]
        # Keep edge (p, v): v's open component joins p's, XOR bits combine.
        keep0 = p0 * c0 + p1 * c1
        keep1 = p0 * c1 + p1 * c0
        # Cut edge (p, v): v's component is closed off, which is only
        # allowed when its XOR bit is 1; p's open bit is unchanged.
        cut0 = p0 * c1
        cut1 = p1 * c1
        dp0[p] = (keep0 + cut0) % MOD
        dp1[p] = (keep1 + cut1) % MOD

    # The root's own (still open) component must have XOR bit 1 as well.
    return dp1[0] % MOD


def compute_answer(N, next_nodes, A):
    """Return the sum over all edge choices of AND(component XORs) mod MOD."""
    order, parents = _bfs_order(N, next_nodes)
    answer = 0
    # Only bits present in some A[i] can ever be set in a component XOR.
    for k in range(max(A).bit_length()):
        bits = [(a >> k) & 1 for a in A]
        if not any(bits):
            continue  # every component XOR has bit k clear -> count is 0
        count = solve(N, next_nodes, bits, order, parents)
        answer = (answer + count * pow(2, k, MOD)) % MOD
    return answer


def main():
    """Read N, the N-1 edges (1-indexed) and A from stdin; print the answer."""
    data = sys.stdin.buffer.read().split()
    N = int(data[0])
    next_nodes = [[] for _ in range(N)]
    for i in range(N - 1):
        u = int(data[1 + 2 * i]) - 1
        v = int(data[2 + 2 * i]) - 1
        next_nodes[u].append(v)
        next_nodes[v].append(u)
    offset = 1 + 2 * (N - 1)
    A = [int(x) for x in data[offset:offset + N]]
    print(compute_answer(N, next_nodes, A))


if __name__ == '__main__':
    main()