import sys

MOD = 998244353


def sum_and_of_component_xors(n, edge_pairs, values, mod=MOD):
    """Sum, over every subset of tree edges to cut, the AND of component XORs.

    Cutting a subset of the ``n - 1`` edges splits the tree into connected
    components.  Each component is scored by the XOR of its node values, and
    the whole partition is scored by the bitwise AND of those XORs.  The
    result is the sum of partition scores over all ``2 ** (n - 1)`` cut
    subsets (including cutting nothing), reduced modulo ``mod``.

    The sum is taken bit by bit: for each bit ``b`` we count the cut subsets
    in which *every* component's XOR has bit ``b`` set and add
    ``count * 2 ** b``.

    Args:
        n: Number of nodes; nodes are ``0 .. n - 1`` and node 0 is the root.
        edge_pairs: ``n - 1`` pairs ``(u, v)`` of 0-indexed endpoints that
            form a tree (orientation does not matter).
        values: Non-negative integer value of each node (length ``n``).
        mod: Modulus applied to the result and intermediate counts.

    Returns:
        The sum described above, modulo ``mod``.
    """
    if n <= 0:
        return 0

    adjacency = [[] for _ in range(n)]
    for u, v in edge_pairs:
        adjacency[u].append(v)
        adjacency[v].append(u)

    # BFS from node 0: records each node's parent and an order in which every
    # parent precedes its children.
    parent = [-1] * n
    order = [0]
    head = 0
    while head < len(order):
        node = order[head]
        head += 1
        for nxt in adjacency[node]:
            if nxt != parent[node]:
                parent[nxt] = node
                order.append(nxt)

    # (child, parent) merges in bottom-up order; a child appears only after
    # all of its own children have been merged into it.  Shared by all bits.
    merges = [(child, parent[child]) for child in reversed(order[1:])]

    total = 0
    # Bits above the highest set bit of every value are 0 in every component,
    # so they contribute nothing; this also covers values wider than 31 bits.
    for b in range(max(values, default=0).bit_length()):
        # odd[v] / even[v]: number of cut subsets inside v's subtree where all
        # already-separated components have bit b set and v's open component
        # currently has bit-b parity 1 / 0.  Start from v alone.
        odd = [(x >> b) & 1 for x in values]
        even = [1 - bit for bit in odd]
        for child, par in merges:
            c_even, c_odd = even[child], odd[child]
            p_even, p_odd = even[par], odd[par]
            either = p_even + p_odd
            # Keeping the edge XORs the parities.  Cutting it is allowed only
            # if the child's finished component has bit b set (c_odd), and it
            # leaves the parent's parity unchanged.  Both terms for c_odd
            # therefore multiply (p_even + p_odd).
            even[par] = (c_even * p_even + c_odd * either) % mod
            odd[par] = (c_even * p_odd + c_odd * either) % mod
        # The root's open component must also have bit b set.
        total = (total + (1 << b) * odd[0]) % mod
    return total


def main():
    """Read ``N``, ``N - 1`` 1-indexed edges and ``N`` values; print the sum."""
    tokens = sys.stdin.buffer.read().split()
    n = int(tokens[0])
    ends = list(map(int, tokens[1:2 * n - 1]))
    edge_pairs = [(ends[i] - 1, ends[i + 1] - 1) for i in range(0, len(ends), 2)]
    values = list(map(int, tokens[2 * n - 1:3 * n - 1]))
    print(sum_and_of_component_xors(n, edge_pairs, values))


if __name__ == "__main__":
    main()