#include using ll = std::int64_t; void dfs( int v, int p, int &counter, std::vector &in, std::vector &out, std::vector &diff, const std::vector> &G, const std::string &S ){ in[v] = counter; counter += 1; diff[in[v]] = S[v - 1] == '1' ? 1 : -1; if(p != -1){ diff[in[v]] += diff[in[p]]; } for(auto &w : G[v]){ if(w != p){ dfs(w, v, counter, in, out, diff, G, S); } } out[v] = counter; } ll search( int x, int l, int r, const std::vector &diff, const std::vector &diff_group, const std::vector &incr, const int B ){ ll res = 0; while(l < r && l % B > 0){ if(diff[l] + incr[l / B] >= x){ res += 1; } l += 1; } while(l < r && r % B > 0){ if(diff[r - 1] + incr[(r - 1) / B] >= x){ res += 1; } r -= 1; } while(l < r){ int last = std::min(l + B, diff_group.size()); res += diff_group.begin() + last - std::lower_bound(diff_group.begin() + l, diff_group.begin() + last, x - incr[l / B]); l = last; } return res; } void add( int x, int l, int r, std::vector &diff, std::vector &diff_group, std::vector &incr, const int B ){ auto update_point = [&](int i){ int first = i / B * B, last = std::min(first + B, diff_group.size()); if(x == -1){ int j = std::lower_bound(diff_group.begin() + first, diff_group.begin() + last, diff[i]) - diff_group.begin(); diff_group[j] += x; }else{ int j = std::upper_bound(diff_group.begin() + first, diff_group.begin() + last, diff[i]) - diff_group.begin(); j -= 1; diff_group[j] += x; } diff[i] += x; }; while(l < r && l % B > 0){ update_point(l); l += 1; } while(l < r && r % B > 0){ update_point(r - 1); r -= 1; } l /= B; r /= B; while(l < r){ incr[l] += x; l += 1; } } ll dfs2( int v, int p, const std::vector &in, const std::vector &out, std::vector &diff, std::vector &diff_group, std::vector &incr, const std::vector> &G, const std::string &S, const int B ){ ll res = 0; res += search(1, 0, diff.size(), diff, diff_group, incr, B); int xv = S[v - 1] == '1' ? 1 : -1; for(auto &w : G[v]){ if(w != p){ int xw = S[w - 1] == '1' ? 1 : -1; add(xw, 0, in[w], diff, diff_group, incr, B); add(-xv, in[w], out[w], diff, diff_group, incr, B); add(xw, out[w], diff.size(), diff, diff_group, incr, B); res += dfs2(w, v, in, out, diff, diff_group, incr, G, S, B); add(-xw, 0, in[w], diff, diff_group, incr, B); add(xv, in[w], out[w], diff, diff_group, incr, B); add(-xw, out[w], diff.size(), diff, diff_group, incr, B); } } return res; } int main(){ std::cin.tie(nullptr); std::ios::sync_with_stdio(false); int N; std::cin >> N; std::vector> G(N + 1); for(int i=0;i> u >> v; G[u].emplace_back(v); G[v].emplace_back(u); } std::string S; std::cin >> S; std::vector in(N + 1), out(N + 1), diff(N, 0); { int counter = 0; dfs(1, -1, counter, in, out, diff, G, S); } const int B = std::sqrt(N); std::vector diff_group(diff), incr((N + B - 1) / B, 0); for(int i=0;i