#include <algorithm>
#include <cassert>
#include <iostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// Centroid decomposition over an undirected graph stored as adjacency lists.
// Every edge has an id; an edge whose `available_edge` entry is 0 is ignored
// by all traversals, which is how the decomposition cuts the graph apart.
struct CentroidDecomposition {
    static constexpr int NO_PARENT = -1; // value of par[] for a root
    int V;                               // number of vertices
    int E;                               // number of edges added so far
    std::vector<std::vector<std::pair<int, int>>> to; // (node_id, edge_id)
    std::vector<int> par;               // parent node_id; par[root] = NO_PARENT
    std::vector<std::vector<int>> chi;  // children ids
    std::vector<int> subtree_size;      // size of each subtree
    std::vector<int> available_edge;    // If 0, ignore the corresponding edge.

    // Creates `v` isolated vertices.
    CentroidDecomposition(int v = 0)
        : V(v), E(0), to(v), par(v, NO_PARENT), chi(v), subtree_size(v) {}

    // Builds from plain adjacency lists. Only pairs with i < j are added, so
    // an edge listed in both directions is inserted exactly once.
    CentroidDecomposition(const std::vector<std::vector<int>> &to_)
        : CentroidDecomposition(static_cast<int>(to_.size())) {
        for (int i = 0; i < V; i++) {
            for (int j : to_[i]) {
                if (i < j) add_edge(i, j);
            }
        }
    }

    // Adds the undirected edge v1 - v2 and marks it as available.
    void add_edge(int v1, int v2) {
        to[v1].emplace_back(v2, E);
        to[v2].emplace_back(v1, E);
        E++;
        available_edge.emplace_back(1);
    }

    // Recomputes par / chi / subtree_size below `now`, entered from `prv`,
    // following available edges only. Returns subtree_size[now].
    int _dfs_fixroot(int now, int prv) {
        chi[now].clear();
        subtree_size[now] = 1;
        for (const auto &nxt : to[now]) {
            if (nxt.first != prv && available_edge[nxt.second]) {
                par[nxt.first] = now;
                chi[now].push_back(nxt.first);
                subtree_size[now] += _dfs_fixroot(nxt.first, now);
            }
        }
        return subtree_size[now];
    }

    // Roots the component reachable from `root` (via available edges) at `root`.
    void fix_root(int root) {
        par[root] = NO_PARENT;
        _dfs_fixroot(root, -1);
    }

    //// Centroid Decomposition ////
    std::vector<int> centroid_cand_tmp;

    // Appends to centroid_cand_tmp every vertex below `now` for which no
    // child subtree and no "upward" part exceeds n / 2 vertices.
    // Reads subtree_size as left by the last fix_root().
    void _dfs_detect_centroids(int now, int prv, int n) {
        bool is_centroid = true;
        for (const auto &nxt : to[now]) {
            if (nxt.first != prv && available_edge[nxt.second]) {
                _dfs_detect_centroids(nxt.first, now, n);
                if (subtree_size[nxt.first] > n / 2) is_centroid = false;
            }
        }
        if (n - subtree_size[now] > n / 2) is_centroid = false;
        if (is_centroid) centroid_cand_tmp.push_back(now);
    }

    // Returns ([centroid_node_id1], ([centroid_node_id2]|-1)) for the rooted
    // tree containing `r`. Walks par[] up to the root first, so par[] and
    // subtree_size[] must already describe that tree (see fix_root()).
    std::pair<int, int> detect_centroids(int r) {
        centroid_cand_tmp.clear();
        while (par[r] != NO_PARENT) r = par[r];
        const int n = subtree_size[r];
        _dfs_detect_centroids(r, -1, n);
        if (centroid_cand_tmp.size() == 1) {
            return std::make_pair(centroid_cand_tmp[0], -1);
        }
        return std::make_pair(centroid_cand_tmp[0], centroid_cand_tmp[1]);
    }

    std::vector<int> _cd_vertices;

    // Picks a centroid of the component containing `now`, records it, then
    // disables each still-available edge at the centroid and recurses into
    // the component on the other side of that edge.
    void _centroid_decomposition(int now) {
        fix_root(now);
        now = detect_centroids(now).first;
        _cd_vertices.emplace_back(now);
        /* do something */
        for (const auto &[nxt, eid] : to[now]) {
            if (available_edge[eid] == 0) continue;
            available_edge[eid] = 0;
            _centroid_decomposition(nxt);
        }
    }

    // Returns the centroids of the component containing `x`, each one listed
    // before the centroids found inside the pieces it splits off.
    // Note: this switches off the edges it cuts (available_edge becomes 0).
    std::vector<int> centroid_decomposition(int x) {
        _cd_vertices.clear();
        _centroid_decomposition(x);
        return _cd_vertices;
    }
};

// 0-indexed BIT (binary indexed tree / Fenwick tree) (i : [0, len))
template <class T> struct BIT {
    int n;
    std::vector<T> data;

    BIT(int len = 0) : n(len), data(len) {}

    // Sets every element back to zero.
    void reset() { std::fill(data.begin(), data.end(), T(0)); }

    // a[pos] += v
    void add(int pos, T v) {
        pos++;
        while (pos > 0 && pos <= n) {
            data[pos - 1] += v;
            pos += pos & -pos;
        }
    }

    // a[0] + ... + a[k - 1]
    T sum(int k) const {
        T res = 0;
        while (k > 0) {
            res += data[k - 1];
            k -= k & -k;
        }
        return res;
    }

    // a[l] + ... + a[r - 1]
    T sum(int l, int r) const { return sum(r) - sum(l); }

    // Prints the individual elements as "[a0,a1,...,]".
    template <class OStream> friend OStream &operator<<(OStream &os, const BIT &bit) {
        T prv = 0;
        os << '[';
        for (int i = 1; i <= bit.n; i++) {
            T now = bit.sum(i);
            os << now - prv << ',';
            prv = now;
        }
        return os << ']';
    }
};

// Reads N, then N - 1 edges given as 1-indexed vertex pairs, then a string S
// of N digits. Vertex i gets weight +1 when S[i] is '1' and -1 when it is '0'.
// Prints the number of paths (single vertices included, each unordered pair of
// endpoints counted once) whose total weight is strictly positive.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int N;
    if (!(std::cin >> N) || N <= 0) {
        // No vertices: nothing to decompose (and vertex 0 would not exist).
        std::cout << 0 << '\n';
        return 0;
    }

    CentroidDecomposition cd(N);
    for (int e = 0; e < N - 1; ++e) {
        int u, v;
        std::cin >> u >> v;
        --u, --v;
        cd.add_edge(u, v);
    }

    std::vector<int> weight(N);
    {
        std::string S;
        std::cin >> S;
        for (int i = 0; i < N; ++i) {
            weight.at(i) = (S.at(i) - '0') * 2 - 1; // '0' -> -1, '1' -> +1
        }
    }

    std::vector<int> is_alive(N, 1); // 0 once a vertex has been used as centroid
    long long ret = 0;

    // Path sums lie in [-N, N]; sum s is stored at index N + s.
    std::vector<int> cnt(N * 2 + 1);
    BIT<int> bit(N * 2 + 1);

    std::vector<int> ws; // sums collected from the subtree currently scanned

    // Walks the alive part of one subtree hanging off the current centroid.
    // `w` is the sum from the subtree's top vertex down to `now` (centroid
    // excluded). For every vertex it counts the already-registered sums x
    // with x + w > 0, i.e. indices N - w + 1 .. 2N of the BIT.
    // Recursion depth can reach the size of the scanned component.
    auto collect = [&](auto &&self, int now, int prv, int w) -> void {
        ws.push_back(w);
        ret += bit.sum(N - w + 1, N * 2 + 1);
        for (const auto &edge : cd.to.at(now)) {
            const int child = edge.first;
            if (child == prv || !is_alive.at(child)) continue;
            self(self, child, now, w + weight.at(child));
        }
    };

    for (int c : cd.centroid_decomposition(0)) {
        is_alive.at(c) = 0;

        // Range of sums registered for this centroid, used to undo them later.
        int lo = 0, hi = 0;
        auto addbit = [&](int w) -> void {
            lo = std::min(lo, w);
            hi = std::max(hi, w);
            cnt.at(N + w)++;
            bit.add(N + w, 1);
        };

        // The path consisting of the centroid alone.
        addbit(weight.at(c));
        if (weight.at(c) > 0) ret++;

        for (const auto &edge : cd.to.at(c)) {
            const int top = edge.first;
            if (!is_alive.at(top)) continue;

            ws.clear();
            collect(collect, top, c, weight.at(top));

            // Register this subtree only after it has been queried, so two
            // vertices of the same subtree are never paired through c.
            for (int w : ws) addbit(w + weight.at(c));
        }

        // Remove exactly what this centroid registered; cost is O(hi - lo).
        for (int s = lo; s <= hi; ++s) {
            const int j = s + N;
            bit.add(j, -cnt.at(j));
            cnt.at(j) = 0;
        }
    }

    std::cout << ret << '\n';
}