#include <cassert>
#include <utility>
#include <vector>

// Centroid Decomposition
// Verification: Codeforces #190 Div.1 C https://codeforces.com/contest/321/submission/59093583
//
// find_current_centroids(int r, int conn_size): enumerate the centroid(s) of the connected
//   component of still-alive vertices that contains `r`; `conn_size` must be the size of
//   that component. Returns {c1, c2} with c2 == -1 when the centroid is unique.
// run(int r, F callback): recursively decompose the component containing `r`, calling
//   callback(c) for every centroid c in the order it is removed.
// centroid_decomposition(int r): the same removal order, collected into a vector.
struct CentroidDecomposition {
    int V;                             // number of vertices
    std::vector<std::vector<int>> to;  // adjacency lists

private:
    std::vector<int> is_alive;      // 1 until the vertex is removed as a centroid
    std::vector<int> subtree_size;  // scratch: subtree sizes from the latest find_current_centroids()

    // Removes a centroid of the component containing `r` (whose size is `conn_size`),
    // reports it, then recurses into every remaining component adjacent to it.
    // `callback` is taken by reference so a stateful callback sees every centroid.
    template <class F> void decompose(int r, int conn_size, F &callback) {
        const int c = find_current_centroids(r, conn_size).first;
        is_alive.at(c) = 0;
        callback(c);
        for (int nxt : to.at(c)) {
            if (!is_alive.at(nxt)) continue;
            // subtree_size is rooted at r, so the only neighbour with a larger subtree than c is
            // c's parent; its side of the cut holds everything outside c's subtree.
            int next_size = subtree_size.at(nxt);
            if (subtree_size.at(nxt) > subtree_size.at(c)) next_size = conn_size - subtree_size.at(c);
            decompose(nxt, next_size, callback);
        }
    }

public:
    CentroidDecomposition(int v = 0) : V(v), to(v), is_alive(v, 1), subtree_size(v) {}

    CentroidDecomposition(int v, const std::vector<std::pair<int, int>> &tree_edges)
        : CentroidDecomposition(v) {
        for (const auto &e : tree_edges) add_edge(e.first, e.second);
    }

    // Adds the undirected edge (v1, v2); both endpoints must be in [0, V) and distinct.
    void add_edge(int v1, int v2) {
        assert(0 <= v1 and v1 < V and 0 <= v2 and v2 < V);
        assert(v1 != v2);
        to.at(v1).push_back(v2), to.at(v2).emplace_back(v1);
    }

    // A vertex is a centroid iff every component left after removing it has at most
    // conn_size / 2 vertices. Side effect: fills subtree_size for the component, rooted at r.
    std::pair<int, int> find_current_centroids(int r, int conn_size) {
        assert(is_alive.at(r));
        const int thres = conn_size / 2;
        int c1 = -1, c2 = -1;

        auto rec_search = [&](auto &&self, int now, int prv) -> void {
            bool is_centroid = true;
            subtree_size.at(now) = 1;
            for (int nxt : to.at(now)) {
                if (nxt == prv or !is_alive.at(nxt)) continue;
                self(self, nxt, now);
                subtree_size.at(now) += subtree_size.at(nxt);
                if (subtree_size.at(nxt) > thres) is_centroid = false;
            }
            // The component on the parent side of `now`.
            if (conn_size - subtree_size.at(now) > thres) is_centroid = false;
            if (is_centroid) (c1 < 0 ? c1 : c2) = now;
        };
        rec_search(rec_search, r, -1);
        return {c1, c2};
    }

    // Revives every vertex reachable from r (ignoring previous removals), counts them,
    // and decomposes that component.
    template <class F> void run(int r, F callback) {
        int conn_size = 0;
        auto rec = [&](auto &&self, int now, int prv) -> void {
            ++conn_size;
            is_alive.at(now) = 1;
            for (int nxt : to.at(now)) {
                if (nxt == prv) continue;
                self(self, nxt, now);
            }
        };
        rec(rec, r, -1);
        decompose(r, conn_size, callback);
    }

    // Centroids of the tree containing r, in removal (centroid-tree preorder) order.
    std::vector<int> centroid_decomposition(int r) {
        std::vector<int> res;
        run(r, [&](int v) { res.push_back(v); });
        return res;
    }
};

#include <algorithm>
#include <array>
#include <bitset>
#include <cassert>
#include <chrono>
#include <cmath>
#include <complex>
#include <deque>
#include <forward_list>
#include <fstream>
#include <functional>
#include <iomanip>
#include <ios>
#include <iostream>
#include <iterator>
#include <limits>
#include <list>
#include <map>
#include <memory>
#include <numeric>
#include <optional>
#include <queue>
#include <random>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
using namespace std;
using lint = long long;
using pint = pair<int, int>;
using plint = pair<lint, lint>;
struct fast_ios {
    fast_ios() { cin.tie(nullptr), ios::sync_with_stdio(false), cout << fixed << setprecision(20); }
} fast_ios_;
#define ALL(x) (x).begin(), (x).end()
#define FOR(i, begin, end) for (int i = (begin), i##_end_ = (end); i < i##_end_; i++)
#define IFOR(i, begin, end) for (int i = (end) - 1, i##_begin_ = (begin); i >= i##_begin_; i--)
#define REP(i, n) FOR(i, 0, n)
#define IREP(i, n) IFOR(i, 0, n)
template <typename T> bool chmax(T &m, const T q) { return m < q ? (m = q, true) : false; }
template <typename T> bool chmin(T &m, const T q) { return m > q ? (m = q, true) : false; }
const std::vector<std::pair<int, int>> grid_dxs{{1, 0}, {-1, 0}, {0, 1}, {0, -1}};
// floor(log2(x)) for x > 0, -1 otherwise.
int floor_lg(long long x) { return x <= 0 ? -1 : 63 - __builtin_clzll(x); }
// floor(num / den); assumes den > 0.
template <class T1, class T2> T1 floor_div(T1 num, T2 den) {
    return (num > 0 ? num / den : -((-num + den - 1) / den));
}
template <class T1, class T2>
std::pair<T1, T2> operator+(const std::pair<T1, T2> &l, const std::pair<T1, T2> &r) {
    return std::make_pair(l.first + r.first, l.second + r.second);
}
template <class T1, class T2>
std::pair<T1, T2> operator-(const std::pair<T1, T2> &l, const std::pair<T1, T2> &r) {
    return std::make_pair(l.first - r.first, l.second - r.second);
}
template <class T> std::vector<T> sort_unique(std::vector<T> vec) {
    sort(vec.begin(), vec.end()), vec.erase(unique(vec.begin(), vec.end()), vec.end());
    return vec;
}
template <class T> int arglb(const std::vector<T> &v, const T &x) {
    return std::distance(v.begin(), std::lower_bound(v.begin(), v.end(), x));
}
template <class T> int argub(const std::vector<T> &v, const T &x) {
    return std::distance(v.begin(), std::upper_bound(v.begin(), v.end(), x));
}
template <class IStream, class T> IStream &operator>>(IStream &is, std::vector<T> &vec) {
    for (auto &v : vec) is >> v;
    return is;
}

// Forward declarations so nested containers can be printed in any order.
template <class OStream, class T> OStream &operator<<(OStream &os, const std::vector<T> &vec);
template <class OStream, class T, std::size_t sz> OStream &operator<<(OStream &os, const std::array<T, sz> &arr);
template <class OStream, class T, class TH> OStream &operator<<(OStream &os, const std::unordered_set<T, TH> &vec);
template <class OStream, class T, class U> OStream &operator<<(OStream &os, const std::pair<T, U> &pa);
template <class OStream, class T> OStream &operator<<(OStream &os, const std::deque<T> &vec);
template <class OStream, class T> OStream &operator<<(OStream &os, const std::set<T> &vec);
template <class OStream, class T> OStream &operator<<(OStream &os, const std::multiset<T> &vec);
template <class OStream, class T> OStream &operator<<(OStream &os, const std::unordered_multiset<T> &vec);
template <class OStream, class TK, class TV> OStream &operator<<(OStream &os, const std::map<TK, TV> &mp);
template <class OStream, class TK, class TV, class TH> OStream &operator<<(OStream &os, const std::unordered_map<TK, TV, TH> &mp);
template <class OStream, class... T> OStream &operator<<(OStream &os, const std::tuple<T...> &tpl);

template <class OStream, class T> OStream &operator<<(OStream &os, const std::vector<T> &vec) {
    os << '[';
    for (auto v : vec) os << v << ',';
    os << ']';
    return os;
}
template <class OStream, class T, std::size_t sz> OStream &operator<<(OStream &os, const std::array<T, sz> &arr) {
    os << '[';
    for (auto v : arr) os << v << ',';
    os << ']';
    return os;
}
template <class... T> std::istream &operator>>(std::istream &is, std::tuple<T...> &tpl) {
    std::apply([&is](auto &&...args) { ((is >> args), ...); }, tpl);
    return is;
}
template <class OStream, class... T> OStream &operator<<(OStream &os, const std::tuple<T...> &tpl) {
    os << '(';
    std::apply([&os](auto &&...args) { ((os << args << ','), ...); }, tpl);
    return os << ')';
}
template <class OStream, class T, class TH> OStream &operator<<(OStream &os, const std::unordered_set<T, TH> &vec) {
    os << '{';
    for (auto v : vec) os << v << ',';
    os << '}';
    return os;
}
template <class OStream, class T> OStream &operator<<(OStream &os, const std::deque<T> &vec) {
    os << "deq[";
    for (auto v : vec) os << v << ',';
    os << ']';
    return os;
}
template <class OStream, class T> OStream &operator<<(OStream &os, const std::set<T> &vec) {
    os << '{';
    for (auto v : vec) os << v << ',';
    os << '}';
    return os;
}
template <class OStream, class T> OStream &operator<<(OStream &os, const std::multiset<T> &vec) {
    os << '{';
    for (auto v : vec) os << v << ',';
    os << '}';
    return os;
}
template <class OStream, class T> OStream &operator<<(OStream &os, const std::unordered_multiset<T> &vec) {
    os << '{';
    for (auto v : vec) os << v << ',';
    os << '}';
    return os;
}
template <class OStream, class T, class U> OStream &operator<<(OStream &os, const std::pair<T, U> &pa) {
    return os << '(' << pa.first << ',' << pa.second << ')';
}
template <class OStream, class TK, class TV> OStream &operator<<(OStream &os, const std::map<TK, TV> &mp) {
    os << '{';
    for (auto v : mp) os << v.first << "=>" << v.second << ',';
    os << '}';
    return os;
}
template <class OStream, class TK, class TV, class TH> OStream &operator<<(OStream &os, const std::unordered_map<TK, TV, TH> &mp) {
    os << '{';
    for (auto v : mp) os << v.first << "=>" << v.second << ',';
    os << '}';
    return os;
}
#ifdef HITONANODE_LOCAL
const string COLOR_RESET = "\033[0m", BRIGHT_GREEN = "\033[1;32m", BRIGHT_RED = "\033[1;31m", BRIGHT_CYAN = "\033[1;36m", NORMAL_CROSSED = "\033[0;9;37m", RED_BACKGROUND = "\033[1;41m", NORMAL_FAINT = "\033[0;2m";
#define dbg(x) std::cerr << BRIGHT_CYAN << #x << COLOR_RESET << " = " << (x) << NORMAL_FAINT << " (L" << __LINE__ << ") " << __FILE__ << COLOR_RESET << std::endl
#define dbgif(cond, x) ((cond) ? std::cerr << BRIGHT_CYAN << #x << COLOR_RESET << " = " << (x) << NORMAL_FAINT << " (L" << __LINE__ << ") " << __FILE__ << COLOR_RESET << std::endl : std::cerr)
#else
#define dbg(x) ((void)0)
#define dbgif(cond, x) ((void)0)
#endif

#include <algorithm>
#include <vector>

// 0-indexed BIT (binary indexed tree / Fenwick tree) (i : [0, len))
template <class T> struct BIT {
    int n;
    std::vector<T> data;
    BIT(int len = 0) : n(len), data(len) {}
    void reset() { std::fill(data.begin(), data.end(), T(0)); }
    void add(int pos, T v) { // a[pos] += v
        pos++;
        while (pos > 0 and pos <= n) data[pos - 1] += v, pos += pos & -pos;
    }
    T sum(int k) const { // a[0] + ... + a[k - 1]
        T res = 0;
        while (k > 0) res += data[k - 1], k -= k & -k;
        return res;
    }
    T sum(int l, int r) const { return sum(r) - sum(l); } // a[l] + ... + a[r - 1]

    template <class OStream> friend OStream &operator<<(OStream &os, const BIT &bit) {
        T prv = 0;
        os << '[';
        for (int i = 1; i <= bit.n; i++) {
            T now = bit.sum(i);
            os << now - prv << ',', prv = now;
        }
        return os << ']';
    }
};

// Reads a tree with N vertices (1-indexed edges) and a string S; vertex i gets weight +1 if
// S[i] == '1' and -1 if S[i] == '0' (presumably S is a 0/1 string — TODO confirm).
// Prints the number of vertex pairs {u, v} (u == v allowed) whose path weight sum is > 0.
int main() {
    int N;
    cin >> N;
    if (N <= 0) {  // empty tree: no paths (centroid_decomposition(0) would be out of range)
        cout << 0 << '\n';
        return 0;
    }
    CentroidDecomposition cd(N);
    for (int e = 0; e < N - 1; ++e) {
        int u, v;
        cin >> u >> v;
        --u, --v;
        cd.add_edge(u, v);
    }
    vector<int> V(N);
    {
        std::string S;
        cin >> S;
        for (int i = 0; i < N; ++i) V.at(i) = (S.at(i) - '0') * 2 - 1;
    }

    std::vector<int> is_alive(N, 1);
    lint ret = 0;
    // Path sums lie in [-N, N]; sum s is stored at index N + s.
    std::vector<int> cnt(N * 2 + 1);
    BIT<int> bit(N * 2 + 1);

    // Replaying the removal order reproduces each centroid's component: every vertex still
    // alive and connected to c belongs to c's subtree of the centroid tree.
    for (int c : cd.centroid_decomposition(0)) {
        is_alive.at(c) = 0;
        int lo = 0, hi = 0;  // range of stored sums, so the reset below is O(component)
        auto addbit = [&](int w) -> void {
            chmin(lo, w), chmax(hi, w);
            cnt.at(N + w)++;
            bit.add(N + w, 1);
        };
        addbit(V.at(c));             // the path consisting of c alone, usable as a half-path
        if (V.at(c) > 0) ret++;      // and counted once as a single-vertex path
        for (int child : cd.to.at(c)) {
            if (!is_alive.at(child)) continue;
            std::vector<int> ws;     // sums (excluding c) of paths from c into this child's branch
            auto dfs = [&](auto &&self, int now, int prv, int w) -> void {
                ws.push_back(w);
                // Pair with stored half-paths a (which include c) such that a + w > 0.
                ret += bit.sum(-w + N + 1, N * 2 + 1);
                for (int nxt : cd.to.at(now)) {
                    if (nxt == prv or !is_alive.at(nxt)) continue;
                    self(self, nxt, now, w + V.at(nxt));
                }
            };
            dfs(dfs, child, c, V.at(child));
            // Insert only after querying so pairs within the same branch are not counted here.
            for (int w : ws) addbit(w + V.at(c));
        }
        for (int i = lo; i <= hi; ++i) {
            const int j = i + N;
            bit.add(j, -cnt.at(j));
            cnt.at(j) = 0;
        }
    }
    cout << ret << '\n';
}