// NOTE(review): this file was recovered from a copy in which every `<...>`
// span had been stripped (include target, template arguments, and two loop
// headers). The restored pieces are marked below; the container element types
// are inferred from how each variable is used — confirm against the original.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <stack>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

using ll = std::int64_t;

/// Adds `x` to the single stored value `diff[i]`.
///
/// The effective value of index `i` is `diff[i] + incr[i / B]`; `count` is the
/// number of indices whose effective value is >= 1. Only the transitions
/// 1 -> 0 (x == -1) and 0 -> 1 (x == +1) adjust `count`, so callers are
/// expected to pass x == +1 or x == -1.
///
/// @param x          Increment to apply (+1 or -1).
/// @param i          Index into `diff`.
/// @param diff       Stored per-index values (without the lazy block offset).
/// @param diff_group Per-block histogram: stored value -> number of indices.
/// @param incr       Per-block lazy offset added to every index of the block.
/// @param count      Running number of indices with effective value >= 1.
/// @param B          Block size.
void point_update(
    int x, int i,
    std::vector<int> &diff,
    std::vector<std::unordered_map<int, int>> &diff_group,
    std::vector<int> &incr,
    int &count,
    const int B
){
    const int ig = i / B;
    const int effective = diff[i] + incr[ig];
    if(effective == 1 && x == -1){
        count -= 1;  // value drops from 1 to 0: no longer counted
    }else if(effective == 0 && x == 1){
        count += 1;  // value rises from 0 to 1: newly counted
    }

    // Move the index from its old histogram bucket to the new one.
    auto &histogram = diff_group[ig];
    histogram[diff[i]] -= 1;
    diff[i] += x;
    histogram[diff[i]] += 1;
}

/// Adds `x` to the effective value of every index in the half-open range
/// [l, r), keeping `count` (number of effective values >= 1) up to date.
///
/// Indices in partially covered blocks are updated one by one through
/// point_update(); fully covered blocks only have their lazy offset changed,
/// with `count` corrected from the block histogram.
///
/// @param x          Increment to apply (+1 or -1).
/// @param l          First index of the range (inclusive).
/// @param r          End of the range (exclusive).
/// @param diff       Stored per-index values.
/// @param diff_group Per-block histogram of stored values.
/// @param incr       Per-block lazy offsets.
/// @param count      Running number of indices with effective value >= 1.
/// @param B          Block size.
void add(
    int x, int l, int r,
    std::vector<int> &diff,
    std::vector<std::unordered_map<int, int>> &diff_group,
    std::vector<int> &incr,
    int &count,
    const int B
){
    // Leading partial block.
    while(l < r && l % B > 0){
        point_update(x, l, diff, diff_group, incr, count, B);
        l += 1;
    }
    // Trailing partial block.
    while(l < r && r % B > 0){
        point_update(x, r - 1, diff, diff_group, incr, count, B);
        r -= 1;
    }

    // Whole blocks: [l / B, r / B).
    for(int block = l / B, block_end = r / B; block < block_end; ++block){
        // Stored value whose effective value crosses the 0/1 boundary:
        // effective 0 when incrementing, effective 1 when decrementing.
        const int boundary = (x == 1 ? 0 : 1) - incr[block];
        // find() instead of operator[] so probing does not insert empty buckets.
        const auto &histogram = diff_group[block];
        const auto it = histogram.find(boundary);
        if(it != histogram.end()){
            if(x == 1){
                count += it->second;
            }else{
                count -= it->second;
            }
        }
        incr[block] += x;
    }
}

/// Reads a tree with N vertices (N - 1 edges, 1-indexed) and a string S of
/// '0'/'1' labels, then prints (sum over every root of the number of vertices
/// whose root-to-vertex path has more '1' than other labels, minus n) / 2 + n,
/// where n is the number of '1' characters in S.
int main(){
    std::cin.tie(nullptr);
    std::ios::sync_with_stdio(false);

    int N = 0;
    if(!(std::cin >> N) || N <= 0){
        // Empty or unreadable input: there are no paths to count. This also
        // keeps the block size below from being zero.
        std::cout << 0 << '\n';
        return 0;
    }

    std::vector<std::vector<int>> G(N + 1);
    // NOTE(review): loop header and declaration of u, v reconstructed
    // (a tree has N - 1 edges) — confirm against the original.
    for(int i = 0; i < N - 1; i++){
        int u, v;
        std::cin >> u >> v;
        G[u].emplace_back(v);
        G[v].emplace_back(u);
    }
    std::string S;
    std::cin >> S;

    // +1 for a vertex labelled '1', -1 otherwise (vertices are 1-indexed).
    const auto sign_of = [&S](int v){ return S[v - 1] == '1' ? 1 : -1; };

    // Euler tour from vertex 1: the subtree of v occupies [in[v], out[v]).
    // diff[in[v]] is the sum of signs on the path from vertex 1 to v.
    std::vector<int> in(N + 1), out(N + 1), diff(N, 0);
    {
        int counter = 0;
        // Entries are (phase, vertex, parent); phase 0 = enter, 1 = leave.
        std::stack<std::tuple<int, int, int>> stack;
        stack.emplace(0, 1, -1);
        while(!stack.empty()){
            auto [t, v, p] = stack.top();
            stack.pop();
            if(t == 0){
                in[v] = counter;
                counter += 1;
                diff[in[v]] = sign_of(v);
                if(p != -1){
                    diff[in[v]] += diff[in[p]];
                }
                stack.emplace(1, v, p);
                for(auto &w : G[v]){
                    if(w != p){
                        stack.emplace(0, w, v);
                    }
                }
            }else{
                out[v] = counter;
            }
        }
    }

    // Block size ~ sqrt(N); clamped defensively so it can never be zero.
    const int B = std::max(1, static_cast<int>(std::sqrt(N)));
    const int total = static_cast<int>(diff.size());
    int count = 0;
    std::vector<std::unordered_map<int, int>> diff_group((N + B - 1) / B);
    std::vector<int> incr((N + B - 1) / B, 0);
    // NOTE(review): loop body reconstructed from the invariants that
    // point_update() and add() rely on (histogram filled, count = number of
    // values >= 1) — confirm against the original.
    for(int i = 0; i < N; i++){
        diff_group[i / B][diff[i]] += 1;
        if(diff[i] >= 1){
            count += 1;
        }
    }

    ll res = 0;
    {
        // Second traversal: move the root along each edge p -> v, apply the
        // change on entering v and undo it on leaving v.
        std::stack<std::tuple<int, int, int>> stack;
        stack.emplace(0, 1, -1);
        while(!stack.empty()){
            auto [t, v, p] = stack.top();
            stack.pop();
            if(t == 0){
                stack.emplace(1, v, p);
                if(p != -1){
                    const int xv = sign_of(v);
                    const int xp = sign_of(p);
                    // Outside v's subtree the path gains v; inside it loses p.
                    add(xv, 0, in[v], diff, diff_group, incr, count, B);
                    add(-xp, in[v], out[v], diff, diff_group, incr, count, B);
                    add(xv, out[v], total, diff, diff_group, incr, count, B);
                }
                res += count;
                for(auto &w : G[v]){
                    if(w != p){
                        stack.emplace(0, w, v);
                    }
                }
            }else{
                if(p != -1){
                    const int xv = sign_of(v);
                    const int xp = sign_of(p);
                    // Exact inverse of the three updates applied on entry.
                    add(-xv, 0, in[v], diff, diff_group, incr, count, B);
                    add(xp, in[v], out[v], diff, diff_group, incr, count, B);
                    add(-xv, out[v], total, diff, diff_group, incr, count, B);
                }
            }
        }
    }

    // Each root counts itself when labelled '1' (n such cases); every other
    // counted (root, vertex) pair appears once from each endpoint.
    const int n = static_cast<int>(std::count(S.begin(), S.end(), '1'));
    res -= n;
    res = res / 2 + n;
    std::cout << res << '\n';
}