//line 1 "answer.cpp" #include #include #include using namespace std; using ll = long long; using mint = atcoder::modint998244353; vector multiply(vector a, vector b) { vector a_mod(a.begin(), a.end()); vector b_mod(a.begin(), a.end()); auto result_mod = atcoder::convolution(a, b); vector result(result_mod.begin(), result_mod.end()); return result; } int main() { ll n; cin >> n; vector g(n, vector(0)); for (size_t i = 0; i < n - 1; i++) { ll u, v; cin >> u >> v; --u; --v; g[u].push_back(v); g[v].push_back(u); } string si; cin >> si; vector s(n); for (size_t i = 0; i < n; i++) s[i] = (si[i] == '0' ? -1 : 1); ll ans = 0; vector removed(n, false); vector sz(n, 0); auto centroid_decomposition = [&](auto self, ll x) -> void { auto calc_total_size = [&](auto self, ll u, ll par) -> ll { ll result = 1; for (auto v : g[u]) if (v != par && !removed[v]) result += self(self, v, u); return result; }; ll total_size = calc_total_size(calc_total_size, x, -1); ll m = (total_size + 1) / 2; auto find_centroid = [&](auto self, ll u, ll par) -> pair { ll c = -1; ll mn = n + 1; sz[u] = 1; for (auto v : g[u]) { if (v == par || removed[v]) continue; auto [cn, mnn] = self(self, v, u); if (mnn < mn) { mn = mnn; c = cn; } sz[u] += sz[v]; } if (sz[u] >= m && sz[u] < mn) { mn = sz[u]; c = u; } return {c, mn}; }; auto [c, _] = find_centroid(find_centroid, x, -1); vector> f(0); vector getas; for (auto v : g[c]) { if (removed[v]) continue; ll mx = 0; auto calc_geta = [&](auto self, ll u, ll par, ll x) -> ll { ll result = x; mx = max(mx, x); for (auto v : g[u]) { if (v == par || removed[v]) continue; result = min(result, self(self, v, u, x + s[v])); } return result; }; ll geta = max(0LL, -calc_geta(calc_geta, v, c, s[v])); f.push_back(vector(mx + geta + 1, 0)); getas.push_back(geta); auto calc_f = [&](auto self, ll u, ll par, ll x) -> void { f[f.size() - 1][x + geta] += 1; for (auto v : g[u]) { if (v == par || removed[v]) continue; self(self, v, u, x + s[v]); } }; calc_f(calc_f, v, c, s[v]); } ll gmax = 0; for (auto gi : getas) gmax = max(gmax, gi); vector total(gmax + sz[x], 0); for (size_t i = 0; i < f.size(); i++) { for (size_t j = 0; j < f[i].size(); j++) { total[j - getas[i] + gmax] += f[i][j]; } } for (size_t i = gmax + 1 - s[c]; i < total.size(); i++) ans += total[i]; total = multiply(total, total); gmax *= 2; for (size_t i = 0; i < f.size(); i++) { f[i] = multiply(f[i], f[i]); getas[i] *= 2; for (size_t j = 0; j < f[i].size(); j++) { if (f[i][j] != 0) total[j - getas[i] + gmax] -= f[i][j]; } } for (size_t i = gmax + 1 - s[c]; i < total.size(); i++) ans += total[i] / 2; removed[c] = true; for (auto v : g[c]) { if (!removed[v]) self(self, v); } return; }; centroid_decomposition(centroid_decomposition, 0); ll add = 0; for (size_t i = 0; i < n; i++) add += (si[i] - '0'); cout << ans + add << endl; return 0; }