#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

namespace {

// Counts index pairs (i, k) with i <= k and v[i] + v[k] + add > 0.
// Sorts `v` in place (callers only care about the multiset of values).
// Pairs with i == k are included, so a lone value x counts when 2x + add > 0.
long long count_positive_pairs(std::vector<int>& v, int add) {
    std::sort(v.begin(), v.end());
    long long pairs = 0;
    // For the current i, j is the smallest index with v[i] + v[j] + add > 0.
    // v[i] only grows as i advances, so j only ever moves left (two pointers).
    std::size_t j = v.size();
    for (std::size_t i = 0; i < v.size(); ++i) {
        while (j > 0 && v[i] + v[j - 1] + add > 0) {
            --j;
        }
        pairs += static_cast<long long>(v.size() - std::max(i, j));
    }
    return pairs;
}

// Centroid decomposition over a tree given as an adjacency list.
// Counts unordered vertex pairs {a, b} (a == b allowed, i.e. single vertices
// count too) whose connecting path has a positive label balance, where a
// vertex labelled '1' contributes +1 and any other label contributes -1.
// `adj` and `labels` are borrowed and must outlive the counter; `labels`
// must have at least adj.size() characters.
class PositivePathCounter {
public:
    PositivePathCounter(const std::vector<std::vector<int>>& adj,
                        const std::string& labels)
        : adj_(adj),
          labels_(labels),
          size_(adj.size(), 0),
          parent_(adj.size(), -1),
          removed_(adj.size(), 0) {}

    // Decomposes the component containing vertex 0 and returns the count.
    // NOTE(review): assumes the input graph is a connected tree; vertices not
    // reachable from 0 are ignored (same as the original) — confirm spec.
    long long run() {
        total_ = 0;
        std::fill(removed_.begin(), removed_.end(), 0);
        if (!adj_.empty()) {
            decompose(0);
        }
        return total_;
    }

private:
    // One pending vertex of the iterative balance collection.
    struct Frame {
        int node;
        int parent;
        int balance;  // balance of the path from the branch start up to parent
    };

    // +1 for a '1' label, -1 for anything else.
    int weight(int v) const { return labels_[v] == '1' ? 1 : -1; }

    // Fills size_[x] for every non-removed vertex x reachable from `root`,
    // with the component rooted at `root`. Iterative (BFS) so path-shaped
    // trees cannot overflow the call stack.
    void compute_sizes(int root) {
        order_.clear();
        order_.push_back(root);
        parent_[root] = -1;
        for (std::size_t k = 0; k < order_.size(); ++k) {
            const int x = order_[k];
            size_[x] = 1;
            for (int y : adj_[x]) {
                if (!removed_[y] && y != parent_[x]) {
                    parent_[y] = x;
                    order_.push_back(y);
                }
            }
        }
        // Reverse BFS order visits every child before its parent.
        for (std::size_t k = order_.size(); k-- > 1;) {
            size_[parent_[order_[k]]] += size_[order_[k]];
        }
    }

    // Walks from `root` toward any child whose subtree holds more than half
    // of the component; stops at the centroid. Requires compute_sizes(root).
    int find_centroid(int root) const {
        const int half = size_[root] / 2;
        int current = root;
        int previous = -1;
        bool moved = true;
        while (moved) {
            moved = false;
            for (int v : adj_[current]) {
                if (!removed_[v] && v != previous && size_[v] > half) {
                    previous = current;
                    current = v;
                    moved = true;
                    break;
                }
            }
        }
        return current;
    }

    // Appends to `out` the balance of the path start..x (excluding `parent`)
    // for every vertex x in the branch hanging off `parent` at `start`.
    // Iterative so deep branches cannot overflow the call stack.
    void collect_balances(int start, int parent, std::vector<int>& out) {
        stack_.clear();
        stack_.push_back(Frame{start, parent, 0});
        while (!stack_.empty()) {
            const Frame f = stack_.back();
            stack_.pop_back();
            const int balance = f.balance + weight(f.node);
            out.push_back(balance);
            for (int y : adj_[f.node]) {
                if (!removed_[y] && y != f.parent) {
                    stack_.push_back(Frame{y, f.node, balance});
                }
            }
        }
    }

    // Counts every qualifying path through the centroid of u's component,
    // then recurses into the remaining pieces (depth O(log N)).
    void decompose(int u) {
        compute_sizes(u);
        const int centroid = find_centroid(u);
        const int add = weight(centroid);
        {
            // Balances measured from the centroid, excluding the centroid's
            // own label (added back via `add`); the leading 0 stands for the
            // centroid itself, so its (0, 0) self-pair is the path {centroid}.
            std::vector<int> all;
            all.reserve(static_cast<std::size_t>(size_[u]));
            all.push_back(0);
            for (int v : adj_[centroid]) {
                if (removed_[v]) {
                    continue;
                }
                branch_.clear();
                collect_balances(v, centroid, branch_);
                // Pairs inside one branch do not really pass through the
                // centroid; remove them from the all-pairs count below.
                total_ -= count_positive_pairs(branch_, add);
                all.insert(all.end(), branch_.begin(), branch_.end());
            }
            total_ += count_positive_pairs(all, add);
        }  // release `all` before recursing
        removed_[centroid] = 1;
        for (int v : adj_[centroid]) {
            if (!removed_[v]) {
                decompose(v);
            }
        }
    }

    const std::vector<std::vector<int>>& adj_;
    const std::string& labels_;
    std::vector<int> size_;
    std::vector<int> parent_;
    std::vector<char> removed_;
    std::vector<int> order_;    // BFS scratch for compute_sizes
    std::vector<Frame> stack_;  // DFS scratch for collect_balances
    std::vector<int> branch_;   // balances of the branch being processed
    long long total_ = 0;       // up to ~N(N+1)/2: needs 64 bits
};

}  // namespace

// Reads N, then N-1 edges (1-indexed vertex pairs), then the label string,
// and prints the number of vertex pairs whose path has more '1' labels than
// other labels.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    std::cin >> n;
    if (n < 0) {
        n = 0;
    }

    std::vector<std::vector<int>> adj(static_cast<std::size_t>(n));
    for (int i = 1; i < n; ++i) {
        int a = 0;
        int b = 0;
        if (!(std::cin >> a >> b) || a < 1 || a > n || b < 1 || b > n) {
            std::cerr << "invalid edge in input\n";
            return 1;
        }
        --a;  // input vertices are 1-indexed
        --b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    std::string labels;
    std::cin >> labels;
    // Missing labels would be read out of range; pad them with '0' (weight
    // -1, the same weight the original derived from the terminating '\0').
    if (labels.size() < adj.size()) {
        labels.resize(adj.size(), '0');
    }

    // NOTE(review): the original output statement was truncated after
    // `cout<`; it presumably printed the accumulated answer — confirm.
    std::cout << PositivePathCounter(adj, labels).run() << '\n';
    return 0;
}