import sys
from itertools import permutations
from heapq import heappop, heappush
from collections import deque
import random
import bisect


def input():
    """Read one line from stdin with trailing whitespace/newline stripped."""
    return sys.stdin.readline().rstrip()


def mi():
    """Read one line from stdin and return its whitespace-separated ints (lazy map)."""
    return map(int, input().split())


def li():
    """Read one line from stdin and return its whitespace-separated ints as a list."""
    return list(mi())


def solve(N, edge, S):
    """Count tree paths whose vertex-weight sum is strictly positive.

    Vertex v gets weight ``2 * S[v] - 1`` (+1 for a 1, -1 for a 0, assuming
    S holds 0/1 values), so a positive sum means the path contains more 1s
    than 0s.  Every unordered pair {u, v} with u != v is counted once, and
    each single vertex is counted when S[v] is 1 (the ``sum(S)`` term).

    The pair count uses centroid decomposition: every path is attributed to
    the first centroid removed that lies on it, and each decomposition level
    does work linear in the component size.

    Args:
        N: number of vertices, labelled 0..N-1.
        edge: adjacency lists of a tree on N vertices.
        S: per-vertex digits (0/1).

    Returns:
        The number of qualifying paths (0 for an empty tree).
    """
    if N == 0:
        # The decomposition below starts from vertex 0, which must exist.
        return 0

    weight = [2 * s - 1 for s in S]

    centroid_done = [False] * N  # True once a vertex has been used as a centroid (removed)
    tmp_parent = [None] * N      # BFS parent inside the component currently processed
    tmp_sz = [None] * N          # subtree size inside the current component
    tmp_dep = [None] * N         # weight sum on the path centroid -> v, excluding the centroid
    tmp_branch = [None] * N      # the centroid's neighbour whose subtree contains v
    tmp_branch_members = [[] for _ in range(N)]  # vertices grouped by tmp_branch

    def component_bfs(root):
        """Return root's live component in BFS order; fills tmp_parent for it."""
        tmp_parent[root] = None
        order = []
        queue = deque([root])
        while queue:
            v = queue.popleft()
            order.append(v)
            for nv in edge[v]:
                if centroid_done[nv] or nv == tmp_parent[v]:
                    continue
                tmp_parent[nv] = v
                queue.append(nv)
        return order

    def find_centroid(root):
        """Return a vertex of root's component whose removal leaves no part larger than half."""
        order = component_bfs(root)
        total = len(order)
        centroid = -1
        # Reverse BFS order visits children before parents, so child sizes are ready.
        for v in reversed(order):
            size = 1
            too_heavy = False
            for nv in edge[v]:
                if centroid_done[nv] or nv == tmp_parent[v]:
                    continue
                if tmp_sz[nv] * 2 > total:
                    too_heavy = True
                size += tmp_sz[nv]
            tmp_sz[v] = size
            # The part "above" v (towards root) must also be at most half.
            if 2 * (total - size) > total:
                too_heavy = True
            if not too_heavy:
                centroid = v
        for v in order:
            tmp_sz[v] = None
            tmp_parent[v] = None
        return centroid

    def count_positive_pairs(n, values, x):
        """Count unordered index pairs i < j with values[i] + values[j] + x > 0.

        Precondition: -n <= values[k] <= n for every k (counting-sort bound).
        """
        # After the suffix sum, cnt[k] = number of values a with a + n >= k.
        cnt = [0] * (2 * n + 1)
        for a in values:
            cnt[a + n] += 1
        for k in range(2 * n - 1, -1, -1):
            cnt[k] += cnt[k + 1]
        ordered = 0
        for a in values:
            # Smallest shifted partner b + n with a + b + x > 0.
            lower = max(n - a - x + 1, 0)
            if lower <= 2 * n:
                ordered += cnt[lower]
            if 2 * a + x > 0:
                ordered -= 1  # a was counted as its own partner
        return ordered // 2

    def count_through(centroid):
        """Count positive paths of >= 2 vertices in centroid's component that contain centroid."""
        order = component_bfs(centroid)
        x = weight[centroid]
        tmp_dep[centroid] = 0
        # BFS order guarantees a vertex's parent is labelled before the vertex.
        for v in order[1:]:
            p = tmp_parent[v]
            tmp_dep[v] = tmp_dep[p] + weight[v]
            branch = v if p == centroid else tmp_branch[p]
            tmp_branch[v] = branch
            tmp_branch_members[branch].append(v)

        deps = [tmp_dep[v] for v in order[1:]]
        # Pairs u, v joined through the centroid (|dep| < len(order)).
        res = count_positive_pairs(len(order), deps, x)
        # Paths with the centroid as one endpoint.
        res += sum(1 for d in deps if d + x > 0)
        # Remove pairs lying in the same branch: their real path skips the centroid.
        for child in edge[centroid]:
            if centroid_done[child]:
                continue
            members = tmp_branch_members[child]
            res -= count_positive_pairs(len(members), [tmp_dep[v] for v in members], x)

        for v in order:
            tmp_parent[v] = None
            tmp_dep[v] = None
            tmp_branch[v] = None
            tmp_branch_members[v] = []
        return res

    total = 0
    roots = deque([0])
    while roots:
        centroid = find_centroid(roots.popleft())
        total += count_through(centroid)
        centroid_done[centroid] = True
        roots.extend(c for c in edge[centroid] if not centroid_done[c])
    return total + sum(S)


def main():
    """Read N, N-1 one-indexed edges and the digit string S from stdin; print the answer."""
    N = int(input())
    edge = [[] for _ in range(N)]
    for _ in range(N - 1):
        u, v = mi()
        edge[u - 1].append(v - 1)
        edge[v - 1].append(u - 1)
    S = [int(c) for c in input()]
    print(solve(N, edge, S))


if __name__ == "__main__":
    main()