#include #define FL(i,a,b) for(ll i=(a);i<=(b);i++) #define FR(i,a,b) for(ll i=(a);i>=(b);i--) #define ll long long using namespace std; const ll MAXN = 1e5 + 10; ll n,sumc=0,sumw=0,ans=0; ll f[MAXN],g[MAXN]; vectorG[MAXN]; char s[MAXN]; void dfs(ll u,ll fa){ if(s[u]=='w') f[u]++; else g[u]++; for(ll v:G[u]){ if(v==fa) continue; dfs(v,u); f[u]+=f[v]; g[u]+=g[v]; } if(s[u]=='w'){ for(ll v:G[u]){ if(v==fa) continue; ans+=g[v]*(sumw-f[v]-1); } ans+=(sumc-g[u])*(f[u]-1); } } signed main(){ scanf("%lld",&n); scanf("%s",s+1); FL(i,1,n-1){ ll u,v; scanf("%lld%lld",&u,&v); G[u].push_back(v); G[v].push_back(u); } FL(i,1,n){ if(s[i]=='w') sumw++; else sumc++; } dfs(1,0); printf("%lld\n",ans); }