"""Count tree paths whose '1'-vertices strictly outnumber the other vertices.

Input (stdin):
    N
    N - 1 lines "u v" (1-indexed tree edges)
    one line of N characters; '1' scores +1, any other character scores -1

Output: the number of unordered simple paths (single-vertex paths included)
whose total score is positive.
"""
from bisect import bisect_left


def count_positive_paths(n, edges, s):
    """Return the number of simple paths in the tree with a positive score sum.

    Uses centroid decomposition: every path is counted exactly once, at the
    first vertex on it that gets chosen as a centroid.  With true centroids the
    decomposition depth is O(log n), giving O(n log^2 n) total time.

    Args:
        n: number of vertices (0-indexed 0..n-1).
        edges: iterable of (u, v) 0-indexed pairs forming a tree on n vertices.
        s: string whose i-th character scores vertex i (+1 for '1', else -1);
           must have at least n characters.

    Returns:
        Count of unordered paths (including single vertices) with sum > 0.
    """
    if n == 0:
        return 0

    graph = [[] for _ in range(n)]
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)
    score = [1 if ch == '1' else -1 for ch in s]
    used = [False] * n  # vertices already consumed as centroids
    size = [0] * n      # scratch: subtree sizes for the current component

    def centroid(root):
        """Return a centroid of the live component containing `root`."""
        # Iterative post-order to fill size[] (avoids Python's recursion limit).
        stack = [(root, -1, False)]
        while stack:
            u, p, done = stack.pop()
            if not done:
                stack.append((u, p, True))
                for v in graph[u]:
                    if v != p and not used[v]:
                        stack.append((v, u, False))
            else:
                size[u] = 1 + sum(size[v] for v in graph[u]
                                  if v != p and not used[v])
        # Walk toward the heavy child.  The comparison must be against the
        # component's TOTAL size; comparing against size[x] of the current
        # vertex overshoots the centroid and degrades depth to O(n).
        total = size[root]
        x, p = root, -1
        while True:
            heavy = next((y for y in graph[x]
                          if y != p and not used[y] and 2 * size[y] > total),
                         None)
            if heavy is None:
                return x
            x, p = heavy, x

    def path_sums(start, parent):
        """Score sums of the paths start -> w for every live w below `start`."""
        out = []
        stack = [(start, parent, score[start])]
        while stack:
            u, p, acc = stack.pop()
            out.append(acc)
            for v in graph[u]:
                if v != p and not used[v]:
                    stack.append((v, u, acc + score[v]))
        return out

    def count_pairs(sorted_vals, need):
        """Ordered index pairs (i, j), i == j allowed, with vals[i] + vals[j] >= need."""
        k = len(sorted_vals)
        return sum(k - bisect_left(sorted_vals, need - a) for a in sorted_vals)

    ans = 0
    pending = [0]
    while pending:
        c = centroid(pending.pop())
        sc = score[c]
        # A path a..c..b is positive iff sum(a..) + sum(b..) + sc > 0,
        # i.e. sum(a..) + sum(b..) >= 1 - sc (scores are integers).
        need = 1 - sc
        if sc > 0:
            ans += 1  # the single-vertex path [c]
        all_sums = []
        twice_cross = 0
        for v in graph[c]:
            if used[v]:
                continue
            pending.append(v)  # processed after `c` is marked used below
            sums = path_sums(v, c)
            ans += sum(a + sc > 0 for a in sums)  # paths from c into this subtree
            sums.sort()
            # Drop pairs that stay inside one subtree (they don't pass through c).
            twice_cross -= count_pairs(sums, need)
            all_sums.extend(sums)
        all_sums.sort()
        twice_cross += count_pairs(all_sums, need)
        # What remains are ordered cross-subtree pairs, each counted twice.
        ans += twice_cross // 2
        used[c] = True
    return ans


def main():
    """Read the tree and score string from stdin and print the answer."""
    n = int(input())
    edges = []
    for _ in range(n - 1):
        u, v = map(int, input().split())
        edges.append((u - 1, v - 1))
    s = input().strip() if n else ""
    print(count_positive_paths(n, edges, s))


if __name__ == "__main__":
    main()