import sys

MOD = 998244353


def _build_adjacency(n, edges):
    """Return adjacency lists for an undirected graph on vertices 0..n-1.

    *edges* holds 1-indexed (u, v) pairs, exactly as they appear in the input.
    """
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u - 1].append(v - 1)
        adj[v - 1].append(u - 1)
    return adj


def _preorder(adj, root=0):
    """Return (order, parent) from an iterative DFS of a tree rooted at *root*.

    Every vertex appears in *order* before all of its descendants, so
    iterating ``reversed(order)`` visits children before their parent.
    ``parent[root]`` is -1. Assumes *adj* describes a tree (no cycles).
    """
    parent = [-1] * len(adj)
    order = []
    stack = [root]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)
    return order, parent


def _count_for_bit(adj, order, parent, values, bit):
    """Count edge-cut sets in which every component's XOR has *bit* set (mod MOD).

    For vertex u, dp0[u] / dp1[u] count the ways to cut edges inside u's
    subtree such that every already-separated component has the bit set in
    its XOR, while u's own (still open) component has XOR bit 0 / 1.
    """
    n = len(adj)
    dp0 = [0] * n
    dp1 = [0] * n
    for u in reversed(order):
        if (values[u] >> bit) & 1:
            zero, one = 0, 1
        else:
            zero, one = 1, 0
        p = parent[u]
        for v in adj[u]:
            if v == p:
                continue
            # Keeping edge (u, v) merges the two open components (their bit
            # parities XOR together); cutting it is allowed only when v's
            # component is closed with the bit set, i.e. dp1[v].
            both = dp0[v] + dp1[v]
            zero, one = (
                (zero * both + one * dp1[v]) % MOD,
                (zero * dp1[v] + one * both) % MOD,
            )
        dp0[u] = zero
        dp1[u] = one
    # The root's open component is the last one and must also have the bit set.
    return dp1[order[0]]


def solve(n, edges, values):
    """Return the sum, over all subsets of tree edges to cut, of the bitwise
    AND of the XORs of the resulting components, modulo MOD.

    Args:
        n: number of vertices.
        edges: the n-1 tree edges as 1-indexed (u, v) pairs.
        values: per-vertex values; assumed non-negative.

    Returns:
        The answer reduced modulo 998244353 (0 for an empty tree).
    """
    if n == 0:
        return 0
    adj = _build_adjacency(n, edges)
    # The traversal order does not depend on the bit, so compute it once.
    order, parent = _preorder(adj)
    # Bits that are 0 in every value contribute nothing, so iterate exactly
    # as many bits as the largest value needs (instead of a fixed 30).
    bits = max(values, default=0).bit_length()
    total = sum(
        _count_for_bit(adj, order, parent, values, bit) * pow(2, bit, MOD)
        for bit in range(bits)
    )
    return total % MOD


def main():
    """Read N, the N-1 edges and the N values from stdin; print the answer."""
    data = sys.stdin.read().split()
    n = int(data[0])
    nums = list(map(int, data[1:]))
    m = 2 * max(n - 1, 0)
    edges = list(zip(nums[0:m:2], nums[1:m:2]))
    values = nums[m:m + n]
    print(solve(n, edges, values))


if __name__ == "__main__":
    main()