"""Edge-wise counting over a tree whose vertices are labelled 'c' / 'w'.

Input on stdin:
    N
    S                  label string; S[i] is the label of vertex i + 1
    N - 1 lines "a b"  tree edges, 1-indexed

Output: the single integer returned by ``solve``.
"""
from collections import Counter
import sys


def solve(n, s, edges):
    """Return the edge-wise pair count for the labelled tree rooted at vertex 0.

    For every non-root vertex ``i`` with parent ``p``, let ``C[i]`` / ``W[i]``
    be the number of vertices in the subtree of ``i`` labelled ``'c'`` /
    not ``'c'``. The result is the sum over all such ``i`` of

    * ``C[i] * (count of 'w' outside the subtree of i, excluding p)``,
      added only when ``s[p] == 'w'``; and
    * ``(count of 'c' outside the subtree of i) * W[i]``, always added.

    NOTE(review): the result depends on which vertex is the root (vertex 0).
    For example, a single edge labelled "cw" gives 1, while "wc" gives 0.
    Confirm that this asymmetry is what the problem intends.

    Args:
        n: number of vertices.
        s: label string indexed by 0-based vertex. Subtree counts treat any
            character other than 'c' as the 'w' side, while the global
            totals count literal 'c' and 'w'. The input presumably contains
            only those two letters.
        edges: ``n - 1`` pairs ``(a, b)`` of 0-indexed vertices forming a tree.

    Returns:
        The accumulated count (0 when the tree has fewer than two vertices).
    """
    if n <= 1:
        return 0

    total = Counter(s)
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    # Iterative DFS from the root, so no deep recursion. It records a
    # pre-order plus parent links; walking that order backwards visits every
    # child before its parent.
    order = []
    parent = [-1] * n
    stack = [0]
    while stack:
        v = stack.pop()
        order.append(v)
        for u in adj[v]:
            if u != parent[v]:
                parent[u] = v
                stack.append(u)

    ans = 0
    sub_c = [0] * n  # 'c' vertices in each subtree (filled bottom-up)
    sub_w = [0] * n  # non-'c' vertices in each subtree (filled bottom-up)
    # order[0] is the root: it has no parent edge, so it is skipped.
    for i in reversed(order[1:]):
        # Children were already merged into sub_c[i] / sub_w[i]; add i itself.
        if s[i] == "c":
            sub_c[i] += 1
        else:
            sub_w[i] += 1
        p = parent[i]
        if s[p] == "w":
            # The '- 1' excludes the parent p itself from the outside 'w' count.
            ans += sub_c[i] * (total["w"] - sub_w[i] - 1)
        ans += (total["c"] - sub_c[i]) * sub_w[i]
        sub_c[p] += sub_c[i]
        sub_w[p] += sub_w[i]
    return ans


def main():
    """Read N, the label string and the 1-indexed edges from stdin; print the answer."""
    readline = sys.stdin.readline
    n = int(readline())
    s = readline().rstrip()  # drop the trailing newline (and any '\r')
    edges = []
    for _ in range(n - 1):
        a, b = map(int, readline().split())
        edges.append((a - 1, b - 1))
    print(solve(n, s, edges))


if __name__ == "__main__":
    main()