#[allow(unused_imports)]
use std::cmp::*;
#[allow(unused_imports)]
use std::collections::*;
#[allow(unused_imports)]
use std::io::{Write, BufWriter};

// Whitespace-token reader for stdin.
// https://qiita.com/tanakh/items/0ba42c7ca36cd29d0ac8
macro_rules! input {
    ($($r:tt)*) => {
        let stdin = std::io::stdin();
        let mut bytes = std::io::Read::bytes(std::io::BufReader::new(stdin.lock()));
        // Returns the next whitespace-delimited token (empty string at EOF).
        let mut next = move || -> String {
            bytes
                .by_ref()
                .map(|r| r.unwrap() as char)
                .skip_while(|c| c.is_whitespace())
                .take_while(|c| !c.is_whitespace())
                .collect()
        };
        input_inner!{next, $($r)*}
    };
}

// Peels `name: type` pairs off the token list one at a time, binding each with `let`.
macro_rules! input_inner {
    ($next:expr) => {};
    ($next:expr,) => {};
    ($next:expr, $var:ident : $t:tt $($r:tt)*) => {
        let $var = read_value!($next, $t);
        input_inner!{$next $($r)*}
    };
}

// Parses one value of the given type spec:
// a tuple `(a, b, ..)`, an array `[t; len]`, `chars` (a word as Vec<char>),
// `usize1` (a 1-based index converted to 0-based), or any `FromStr` type.
macro_rules! read_value {
    ($next:expr, ( $($t:tt),* )) => {
        ( $(read_value!($next, $t)),* )
    };
    ($next:expr, [ $t:tt ; $len:expr ]) => {
        (0..$len).map(|_| read_value!($next, $t)).collect::<Vec<_>>()
    };
    ($next:expr, chars) => {
        read_value!($next, String).chars().collect::<Vec<char>>()
    };
    ($next:expr, usize1) => {
        (read_value!($next, usize) - 1)
    };
    ($next:expr, $t:ty) => {
        $next().parse::<$t>().expect("Parse error")
    };
}

/// Binary-search helpers for a slice sorted in ascending order.
trait Bisect<T> {
    /// Returns the first index `i` with `self[i] >= *val`, or `self.len()` if none.
    fn lower_bound(&self, val: &T) -> usize;
    /// Returns the first index `i` with `self[i] > *val`, or `self.len()` if none.
    fn upper_bound(&self, val: &T) -> usize;
}

impl<T: Ord> Bisect<T> for [T] {
    fn lower_bound(&self, val: &T) -> usize {
        // Search over 1-based positions: `pass` always satisfies the predicate
        // (position len + 1 is a virtual sentinel), `fail` never does
        // (position 0 is a virtual sentinel). Element at position p is self[p - 1].
        let mut pass = self.len() + 1;
        let mut fail = 0;
        while pass - fail > 1 {
            let mid = (pass + fail) / 2;
            if &self[mid - 1] >= val {
                pass = mid;
            } else {
                fail = mid;
            }
        }
        pass - 1
    }

    fn upper_bound(&self, val: &T) -> usize {
        // Same scheme as `lower_bound`, with a strict comparison.
        let mut pass = self.len() + 1;
        let mut fail = 0;
        while pass - fail > 1 {
            let mid = (pass + fail) / 2;
            if &self[mid - 1] > val {
                pass = mid;
            } else {
                fail = mid;
            }
        }
        pass - 1
    }
}

// Returns (root, children)
// This function uses O(n) stack space.
// Complexity: O(n log n)-time, O(n)-space // Verified by: ABC291-Ex (https://atcoder.jp/contests/abc291/submissions/39303290) fn centroid_decompose(g: &[Vec]) -> (usize, Vec>) { fn find_subtree_sizes(g: &[Vec], v: usize, par: usize, dp: &mut [usize], vis: &[bool]) { let mut sum = 1; for &w in &g[v] { if par == w || vis[w] { continue; } find_subtree_sizes(g, w, v, dp, vis); sum += dp[w]; } dp[v] = sum; } fn centroid_decompose_inner(g: &[Vec], v: usize, par: usize, ch: &mut [Vec], dp: &mut [usize], vis: &mut [bool]) -> usize { let n = g.len(); find_subtree_sizes(g, v, n, dp, vis); let cent = { let sz = dp[v]; let find_centroid = |mut v: usize, mut par: usize| { loop { let mut has_majority = false; for &w in &g[v] { if par == w || vis[w] { continue; } if dp[w] > sz / 2 { par = v; v = w; has_majority = true; break; } } if !has_majority { return v; } } }; find_centroid(v, n) }; if par < n { ch[par].push(cent); } // v was selected as a centroid // and will be ignored in the following decomposition procedure vis[cent] = true; for &w in &g[cent] { if !vis[w] { centroid_decompose_inner(g, w, cent, ch, dp, vis); } } cent } let n = g.len(); let mut ch = vec![vec![]; n]; // This Vec is used across multiple calls to `centroid_decompose_inner` let mut dp = vec![0; n]; let mut vis = vec![false; n]; let root = centroid_decompose_inner(&g, 0, n, &mut ch, &mut dp, &mut vis); (root, ch) } fn main() { // In order to avoid potential stack overflow, spawn a new thread. 
let stack_size = 104_857_600; // 100 MB let thd = std::thread::Builder::new().stack_size(stack_size); thd.spawn(|| solve()).unwrap().join().unwrap(); } fn dfs(v: usize, par: usize, g: &[Vec], s: &[char], vis: &[bool], cur: i32, agg: &mut Vec) { let cur = if s[v] == '1' { cur + 1 } else { cur - 1 }; agg.push(cur); for &w in &g[v] { if par == w || vis[w] { continue; } dfs(w, v, g, s, vis, cur, agg); } } fn calc_comb(init: i32, mut comb: Vec) -> i64 { comb.sort_unstable(); let mut ret = 0; for &v in &comb { let idx = comb.lower_bound(&(-init - v + 1)); ret += (comb.len() - idx) as i64; } ret } // https://yukicoder.me/problems/no/2892 (4) fn solve() { input! { n: usize, uv: [(usize1, usize1); n - 1], s: chars, } let mut g = vec![vec![]; n]; for &(u, v) in &uv { g[u].push(v); g[v].push(u); } let (root, ch) = centroid_decompose(&g); let mut vis = vec![false; n]; let mut que = vec![root]; let mut tot = 0; while let Some(v) = que.pop() { vis[v] = true; for &w in &ch[v] { que.push(w); } // search let mut agg = vec![]; let init = if s[v] == '1' { 1 } else { -1 }; for &c in &ch[v] { let mut me = vec![]; dfs(c, v, &g, &s, &vis, 0, &mut me); agg.extend_from_slice(&me); tot -= calc_comb(init, me); } for &a in &agg { if init + a > 0 { tot += 2; } } tot += calc_comb(init, agg); if init == 1 { tot += 2; } } println!("{}", tot / 2); }