"""Tree edge-cut DP: sum over all subsets of edges to cut of the bitwise AND of
the XORs of the resulting components, modulo 998244353.

Input (stdin, whitespace-separated): n, then n-1 edges ``u v`` (1-indexed),
then n vertex values a_1..a_n.  Output: the sum described above.
"""
import sys

MOD = 998244353


def solve(n, edges, a, mod=MOD):
    """Return sum over edge-cut choices of AND(component XORs), modulo ``mod``.

    For each bit, a DP counts the cut choices in which every component's XOR
    has that bit set; summing ``count << bit`` over bits yields the sum of the
    ANDs of component XORs over all cut choices.

    Args:
        n: number of vertices, labelled 0..n-1.
        edges: iterable of 0-indexed ``(u, v)`` pairs forming a tree on n vertices.
        a: list of n non-negative integer vertex values.
        mod: modulus applied to the result (defaults to 998244353).

    Returns:
        The sum modulo ``mod`` (0 when ``n == 0``).
    """
    if n == 0:
        return 0
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Iterative DFS from vertex 0, recording preorder and parents once.
    # Recursion is avoided: raising the recursion limit does not grow the
    # interpreter's C stack, so deep (path-like) trees could crash.
    parent = [-1] * n
    visited = [False] * n
    visited[0] = True
    order = []
    stack = [0]
    while stack:
        v = stack.pop()
        order.append(v)
        for c in adj[v]:
            if not visited[c]:
                visited[c] = True
                parent[c] = v
                stack.append(c)
    post = order[::-1]  # every child precedes its parent

    total = 0
    # Bits above max(a).bit_length() are clear everywhere and contribute 0.
    for bit in range(max(a, default=0).bit_length()):
        # dp1[v] / dp0[v]: number of cut choices inside v's subtree such that
        # every component already cut off has this bit set in its XOR, and the
        # component still containing v has the bit set (dp1) / clear (dp0).
        dp1 = [x >> bit & 1 for x in a]
        dp0 = [b ^ 1 for b in dp1]
        for v in post:
            p = parent[v]
            if p < 0:
                continue
            c0, c1 = dp0[v], dp1[v]
            p0, p1 = dp0[p], dp1[p]
            # Keep edge p-v: the bits XOR together.  Cut it: v's component is
            # closed off and must have the bit set (factor c1).
            dp0[p] = (p0 * c0 + p1 * c1 + p0 * c1) % mod
            dp1[p] = (p0 * c1 + p1 * c0 + p1 * c1) % mod
        total += dp1[0] << bit
    return total % mod


def main():
    """Read the tree from stdin (1-indexed edges) and print the answer."""
    tokens = iter(sys.stdin.buffer.read().split())
    n = int(next(tokens))
    edges = [(int(next(tokens)) - 1, int(next(tokens)) - 1) for _ in range(n - 1)]
    a = [int(next(tokens)) for _ in range(n)]
    print(solve(n, edges, a))


if __name__ == "__main__":
    main()