"""Count the "majority-one" paths of a labelled tree.

Input (stdin):
    line 1        N, the number of vertices
    next N-1      two 1-based vertex numbers per line, one tree edge each
    last line     a string of N characters; vertex i is "one" iff the
                  i-th character is '1'

Output (stdout):
    the number of simple paths (single vertices included, each path counted
    once regardless of direction) on which the vertices labelled '1'
    strictly outnumber all other vertices.

Every vertex gets weight +1 ('1') or -1 (anything else), so a path
qualifies exactly when its weight sum is positive.  Paths are counted with
centroid decomposition: each path is attributed to the first centroid
removed from it.
"""

from itertools import chain


def pairs(xs, y):
    """Count ordered index pairs (i, j) with ``xs[i] + xs[j] + y > 0``.

    Pairs with ``i == j`` are included and ``(i, j)`` / ``(j, i)`` are
    counted separately.  ``xs`` is not modified.

    Args:
        xs: iterable of numbers.
        y: constant added to every pair sum.

    Returns:
        The number of qualifying ordered pairs.
    """
    ordered = sorted(xs)
    length = len(ordered)
    count = 0
    # Two-pointer sweep: as `value` grows, the smallest partner that still
    # yields a positive sum can only move left, so `first_ok` never rises.
    first_ok = length
    for value in ordered:
        while first_ok >= 1 and value + ordered[first_ok - 1] + y > 0:
            first_ok -= 1
        count += length - first_ok
    return count


def _find_centroid(graph, removed, size, start):
    """Return a centroid of the not-yet-removed component containing ``start``.

    ``size`` is scratch storage; entries for the component's vertices are
    overwritten with subtree sizes relative to ``start``.
    """
    # Iterative post-order traversal: a vertex is pushed once unexpanded and
    # once expanded, so its children's sizes are final when it is summed.
    stack = [(start, None, False)]
    while stack:
        node, parent, expanded = stack.pop()
        if expanded:
            size[node] = 1 + sum(
                size[nxt]
                for nxt in graph[node]
                if not removed[nxt] and nxt != parent
            )
        else:
            stack.append((node, parent, True))
            stack.extend(
                (nxt, node, False)
                for nxt in graph[node]
                if not removed[nxt] and nxt != parent
            )

    # Walk towards the child holding more than half of the component until
    # no such child exists.
    component = size[start]
    node, parent = start, None
    while True:
        heavy = next(
            (
                nxt
                for nxt in graph[node]
                if not removed[nxt]
                and nxt != parent
                and size[nxt] * 2 > component
            ),
            None,
        )
        if heavy is None:
            return node
        node, parent = heavy, node


def _path_sums(graph, removed, weight, start, parent):
    """Return the weight sums of all paths from ``start`` into its branch.

    The branch is everything reachable from ``start`` without passing
    through ``parent`` or a removed vertex.  Each sum includes the weight of
    ``start`` and of the path's far endpoint, but not of ``parent``.
    """
    sums = []
    stack = [(start, parent, weight[start])]
    while stack:
        node, prev, acc = stack.pop()
        sums.append(acc)
        stack.extend(
            (nxt, node, acc + weight[nxt])
            for nxt in graph[node]
            if not removed[nxt] and nxt != prev
        )
    return sums


def solve(n: int, edges, labels) -> int:
    """Count simple paths whose '1' vertices strictly outnumber the rest.

    Args:
        n: number of vertices, numbered ``0 .. n-1``.
        edges: iterable of ``(u, v)`` pairs of 0-based vertex numbers; they
            are expected to form a tree on the ``n`` vertices.
        labels: sequence of at least ``n`` characters; vertex ``i`` weighs
            +1 when ``labels[i] == '1'`` and -1 otherwise.

    Returns:
        The number of qualifying paths, single-vertex paths included.
        Returns 0 when ``n <= 0``.
    """
    if n <= 0:
        return 0

    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
    weight = [1 if ch == '1' else -1 for ch in labels]

    removed = [False] * n
    size = [0] * n
    total = 0
    pending = [0]
    while pending:
        root = _find_centroid(graph, removed, size, pending.pop())
        root_weight = weight[root]

        # The path consisting of the centroid alone.
        if root_weight == 1:
            total += 1

        branches = [nxt for nxt in graph[root] if not removed[nxt]]
        pending.extend(branches)
        branch_sums = [
            _path_sums(graph, removed, weight, nxt, root) for nxt in branches
        ]

        # Paths with the centroid as one endpoint.
        total += sum(
            1 for sums in branch_sums for s in sums if s + root_weight > 0
        )

        # Paths passing through the centroid: count ordered pairs over all
        # branches, drop the pairs that stay inside a single branch (this
        # also removes every i == j pair), then halve to unorder them.
        # chain.from_iterable keeps the concatenation linear.
        combined = list(chain.from_iterable(branch_sums))
        same_branch = sum(pairs(sums, root_weight) for sums in branch_sums)
        total += (pairs(combined, root_weight) - same_branch) // 2

        removed[root] = True
    return total


def main():
    """Read the tree from stdin and print the number of qualifying paths."""
    n = int(input())
    edges = []
    for _ in range(n - 1):
        a, b = (int(tok) - 1 for tok in input().split())
        edges.append((a, b))
    labels = input()
    print(solve(n, edges, labels))


if __name__ == "__main__":
    main()