#ifdef NACHIA
#define _GLIBCXX_DEBUG
#else
#define NDEBUG
#endif
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
using i64 = long long;
using u64 = unsigned long long;
#define rep(i,n) for(int i=0; i<int(n); i++)
template<class A> void chmin(A& l, const A& r){ if(r < l) l = r; }
template<class A> void chmax(A& l, const A& r){ if(l < r) l = r; }
// The AtCoder Library is optional here: Modint is not used by this solution,
// so only pull it in when the header is actually available.
#if __has_include(<atcoder/modint>)
#include <atcoder/modint>
using Modint = atcoder::static_modint<998244353>;
#endif
using namespace std;

// Counts ordered pairs (a, b) of entries in a histogram whose joined path has a
// strictly positive weight.
//   off    : index offset; c[k] counts vertices whose prefix sum is (k - off)
//   c      : histogram of prefix sums (prefix sums exclude the centroid itself)
//   rootAw : +1/-1 weight of the centroid, which every joined path passes through
// A pair (a, b) joins into a path with weight (a - off) + (b - off) + rootAw.
// Pairs are ordered and include (a, a); callers correct for this.
// NOTE(review): the original loop body was lost when the file was mangled; it is
// reconstructed so that a path counts iff its sum is > 0, matching the
// single-vertex check `aw[i] > 0` in countPositivePaths — confirm against the
// problem statement (">= 0" would use j - 1 below).
i64 countPositivePairs(int off, const vector<int>& c, int rootAw){
    const int n = int(c.size());
    vector<i64> cs(n + 1, 0); // cs[k] = c[0] + ... + c[k-1]
    rep(i, n) cs[i + 1] = cs[i] + c[i];
    i64 good = cs[n] * cs[n]; // start from all ordered pairs ...
    rep(i, n){
        // ... and drop the partners k with (i-off)+(k-off)+rootAw <= 0, i.e. k <= j.
        const i64 j = i64(off) * 2 - rootAw - i;
        if(j >= 0) good -= i64(c[i]) * cs[min<i64>(n, j + 1)];
    }
    return good;
}

// Number of (unordered) tree paths, single vertices included, whose sum of
// vertex weights aw[] (+1 / -1) is strictly positive.
// adj must describe a tree on vertices 0..N-1.
i64 countPositivePaths(const vector<vector<int>>& adj, const vector<int>& aw){
    const int N = int(adj.size());
    if(N == 0) return 0;

    // Root the tree at 0: parent pointers and subtree sizes Z.
    vector<int> parent(N, -1);
    vector<int> Z(N, 1);
    {
        vector<int> bfs;
        bfs.reserve(N);
        bfs.push_back(0);
        for(size_t i = 0; i < bfs.size(); i++){
            const int v = bfs[i];
            for(int w : adj[v]) if(parent[v] != w){
                parent[w] = v;
                bfs.push_back(w);
            }
        }
        for(int i = int(bfs.size()) - 1; i >= 1; i--) Z[parent[bfs[i]]] += Z[bfs[i]];
    }

    // Every vertex is also counted once as the pair (centroid, centroid) below;
    // together with this loop single vertices are counted twice, like every
    // multi-vertex path (as two ordered pairs), so the final halving is exact.
    i64 ans = 0;
    rep(i, N) if(aw[i] > 0) ans += 1;

    // Centroid decomposition. v is the root of its current component
    // (parent[v] == -1); removed vertices are marked with Z == 0.
    auto cd = [&](auto& self, int v) -> void {
        // Walk into a child holding more than half of the component, re-rooting
        // the component (sizes and parent pointers) at that child each step.
        while(true){
            int nx = -1;
            for(int w : adj[v]) if(Z[w] * 2 > Z[v]){
                Z[v] -= Z[w];
                Z[w] += Z[v];
                parent[v] = w;
                parent[w] = -1;
                nx = w;
                break;
            }
            if(nx < 0) break;
            v = nx;
        }

        const int off = Z[v];
        vector<int> cnt(off * 2 + 1, 0);
        cnt[off] += 1; // the centroid itself: empty prefix (sum 0)
        Z[v] = 0;      // remove the centroid from the tree

        for(int w : adj[v]) if(Z[w] != 0){
            parent[w] = -1;
            const int offw = Z[w];
            vector<int> cntw(offw * 2 + 1, 0);
            // BFS over w's component, carrying the prefix sum from w down to x.
            vector<pair<int,int>> bfs;
            bfs.push_back({ w, 0 });
            for(size_t q = 0; q < bfs.size(); q++){
                auto [x, c] = bfs[q];
                c += aw[x];
                cntw[c + offw] += 1;
                cnt[c + off] += 1;
                for(int y : adj[x]) if(Z[y] != 0 && parent[x] != y){
                    bfs.push_back({ y, c });
                }
            }
            // Pairs inside the same child component do not form simple paths
            // through v; remove them (inclusion-exclusion).
            ans -= countPositivePairs(offw, cntw, aw[v]);
            self(self, w);
        }
        ans += countPositivePairs(off, cnt, aw[v]);
    };
    cd(cd, 0);

    // Every path was counted as two ordered pairs.
    return ans / 2;
}

// Reads N, the N-1 edges (1-indexed), and a 0/1 string of vertex labels;
// prints the number of paths containing more '1' vertices than '0' vertices.
void testcase(){
    int N; cin >> N;
    vector<vector<int>> adj(N);
    rep(i, N-1){
        int u, v; cin >> u >> v; u--; v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    vector<int> aw(N);
    {
        string s; cin >> s;
        rep(i, N) aw[i] = (s[i] == '1' ? 1 : -1);
    }
    cout << countPositivePaths(adj, aw) << '\n';
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    testcase();
    return 0;
}