"""Count "c"-"w"-"w" patterns along paths of a labelled tree.

Input (stdin): N, then a string S of N vertex labels, then N-1 edges "a b"
(1-indexed).  Output: the number of triples (x, j, y) with S[x] == "c",
S[j] == "w", S[y] == "w" such that j lies strictly inside the tree path x-y.
"""
from collections import Counter
import sys


def solve(n, s, edges):
    """Return the number of triples (x, j, y) with s[x] == "c",
    s[j] == s[y] == "w" and j strictly between x and y on the tree path.

    Args:
        n: number of vertices (vertices are 0 .. n-1).
        s: vertex labels; only s[0 .. n-1] is read.
        edges: n-1 pairs (a, b) of 0-indexed vertices forming a tree.

    Returns:
        The triple count as an int (0 for n == 0).

    For a fixed middle vertex j, removing j splits the tree into one
    component per neighbour.  A valid (x, y) picks x among the "c"s of one
    component and y among the "w"s (other than j) of any other component.
    Each tree edge (v, parent p) contributes exactly one such term for j == v
    (the component through the parent) and one for j == p (the component
    through v), so a single bottom-up pass over the edges suffices.
    """
    if n == 0:
        return 0

    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    # Iterative DFS from vertex 0 (no recursion-depth limits).  `order` is a
    # pre-order, so every parent appears before all of its children.
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
        v = stack.pop()
        order.append(v)
        for u in adj[v]:
            if u != parent[v]:
                parent[u] = v
                stack.append(u)

    counts = Counter(s[:n])
    total_c = counts["c"]
    total_w = counts["w"]
    # below_*[v]: labelled vertices in v's subtree, excluding v itself.
    below_c = [0] * n
    below_w = [0] * n

    answer = 0
    for v in reversed(order):  # children are finished before their parent
        p = parent[v]
        if p < 0:  # the root has no parent edge to account for
            continue
        sub_c = below_c[v] + (s[v] == "c")  # "c"s in v's whole subtree
        sub_w = below_w[v] + (s[v] == "w")  # "w"s in v's whole subtree
        if s[v] == "w":
            # Middle j = v: x on the parent side (outside v's subtree),
            # y anywhere strictly below v.
            answer += (total_c - sub_c) * below_w[v]
        if s[p] == "w":
            # Middle j = p: x inside v's subtree, y outside it (and y != p).
            answer += sub_c * (total_w - sub_w - 1)
        below_c[p] += sub_c
        below_w[p] += sub_w
    return answer


def main():
    """Read the problem from stdin and print the answer."""
    readline = sys.stdin.readline
    n = int(readline())
    s = readline().rstrip()  # drop the trailing newline
    edges = []
    for _ in range(n - 1):
        a, b = map(int, readline().split())
        edges.append((a - 1, b - 1))  # input is 1-indexed
    print(solve(n, s, edges))


if __name__ == "__main__":
    main()