// NOTE(review): this file appears to have been damaged in extraction. Both
// #include targets and every template argument list (e.g. `vector;`,
// `vector> P;`, `vector(2, 0)`) are missing, and the text stops mid-expression
// inside main(). Restore from the original submission before building.
//
// Globals:
//   A - per-vertex integer values, resized to N and read from stdin in main().
//   G - adjacency lists; main() stores each input edge in both directions
//       after converting the endpoints to 0-based indices.
//   P - per-vertex pair of `mint` values (indices 0 and 1), reset to zeros by
//       main() before every dfs() call.
//   mint is an alias of modint998244353 (presumably the AtCoder Library
//   modular-integer type - confirm once the includes are restored).
//
// dfs(n, p, d): recursive traversal from vertex n, where p is the vertex we
// arrived from (main() passes -1 for the start vertex) and d is a bit index.
//   1. Initialise the vertex: P[n][1] = 1 if bit d of A[n] is set, otherwise
//      P[n][0] = 1.
//   2. For every neighbour v of n other than p, recurse into v and then fold
//      v's pair into n's pair:
//        new P[n][0] = P[n][0] * (P[v][0] + P[v][1]) + P[n][1] * P[v][1]
//        new P[n][1] = P[n][1] * (P[v][0] + P[v][1]) + P[n][0] * P[v][1]
//      Both values are computed into temporaries (a, b) before being stored,
//      so each formula reads the pre-update P[n].
// NOTE(review): the combinatorial meaning of the two states is not stated in
// the source; confirm against the problem statement.
//
// main() (visible part only): reads N, then N - 1 edges, then A. For each bit
// index b in 0..30 it resets P, runs dfs(0, -1, b), and adds a term starting
// with P[0][1] * (1ll< ... to the accumulator `an`; the rest is cut off.
#include #include using namespace std; using namespace atcoder; using ll = long long; using vll = vector; using vvll = vector; using mint = modint998244353; using vb = vector; using vvb = vector; using vvvb = vector; #define rep(i,n) for(ll i=(ll)(0); i<(ll(n)); ++i) #define all(x) (x).begin(), (x).end() vll A; vvll G; vector> P; void dfs(ll n, ll p, ll d) { if (A[n] & (1ll << d))P[n][1] = 1; else P[n][0] = 1; for (auto v : G[n]) { if (v != p) { dfs(v, n, d); mint a = P[n][0] * (P[v][0] + P[v][1]) + P[v][1] * P[n][1]; mint b= P[n][1] * (P[v][0] + P[v][1]) + P[v][1] * P[n][0]; P[n][1] = b; P[n][0] = a; } } } int main() { ll N; cin >> N; G.resize(N); rep(i, N - 1) { ll U, V; cin >> U >> V; U--; V--; G[U].push_back(V); G[V].push_back(U); } A.resize(N); rep(i, N)cin >> A[i]; mint an = 0; rep(b, 31) { P.assign(N, vector(2, 0)); dfs(0, -1, b); an += (P[0][1]*(1ll<