結果
問題 |
No.2892 Lime and Karin
|
ユーザー |
![]() |
提出日時 | 2024-09-17 17:06:06 |
言語 | C++17(gcc12) (gcc 12.3.0 + boost 1.87.0) |
結果 |
TLE
|
実行時間 | - |
コード長 | 3,982 bytes |
コンパイル時間 | 8,902 ms |
コンパイル使用メモリ | 279,464 KB |
最終ジャッジ日時 | 2025-02-24 09:08:26 |
ジャッジサーバーID (参考情報) |
judge3 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 26 TLE * 10 MLE * 1 -- * 15 |
ソースコード
#include <bits/stdc++.h> using ll = std::int64_t; void point_update( int x, int i, std::vector<int> &diff, std::vector<std::map<int, int>> &diff_group, std::vector<int> &incr, int &count, const int B ){ int ig = i / B; if(diff[i] + incr[ig] == 1 && x == -1){ count -= 1; }else if(diff[i] + incr[ig] == 0 && x == 1){ count += 1; } diff_group[ig][diff[i]] -= 1; diff[i] += x; diff_group[ig][diff[i]] += 1; } void add( int x, int l, int r, std::vector<int> &diff, std::vector<std::map<int, int>> &diff_group, std::vector<int> &incr, int &count, const int B ){ while(l < r && l % B > 0){ point_update(x, l, diff, diff_group, incr, count, B); l += 1; } while(l < r && r % B > 0){ point_update(x, r - 1, diff, diff_group, incr, count, B); r -= 1; } l /= B; r /= B; while(l < r){ if(x == 1){ count += diff_group[l][-incr[l]]; }else{ count -= diff_group[l][1 - incr[l]]; } incr[l] += x; l += 1; } } int main(){ std::cin.tie(nullptr); std::ios::sync_with_stdio(false); int N; std::cin >> N; std::vector<std::vector<int>> G(N + 1); for(int i=0;i<N-1;i++){ int u, v; std::cin >> u >> v; G[u].emplace_back(v); G[v].emplace_back(u); } std::string S; std::cin >> S; std::vector<int> in(N + 1), out(N + 1), diff(N, 0); { int counter = 0; std::stack<std::tuple<int, int, int>> stack; stack.emplace(0, 1, -1); while(!stack.empty()){ auto [t, v, p] = stack.top(); stack.pop(); if(t == 0){ in[v] = counter; counter += 1; diff[in[v]] = S[v - 1] == '1' ? 1 : -1; if(p != -1){ diff[in[v]] += diff[in[p]]; } stack.emplace(1, v, p); for(auto &w : G[v]){ if(w != p){ stack.emplace(0, w, v); } } }else{ out[v] = counter; } } } const int B = std::sqrt(N); int count = 0; std::vector<std::map<int, int>> diff_group((N + B - 1) / B); std::vector<int> incr((N + B - 1) / B, 0); for(int i=0;i<N;i++){ diff_group[i / B][diff[i]] += 1; if(diff[i] >= 1){ count += 1; } } ll res = 0; { std::stack<std::tuple<int, int, int>> stack; stack.emplace(0, 1, -1); while(!stack.empty()){ auto [t, v, p] = stack.top(); stack.pop(); if(t == 0){ stack.emplace(1, v, p); if(p != -1){ int xv = S[v - 1] == '1' ? 1 : -1; int xp = S[p - 1] == '1' ? 1 : -1; add(xv, 0, in[v], diff, diff_group, incr, count, B); add(-xp, in[v], out[v], diff, diff_group, incr, count, B); add(xv, out[v], diff.size(), diff, diff_group, incr, count, B); } res += count; for(auto &w : G[v]){ if(w != p){ stack.emplace(0, w, v); } } }else{ if(p != -1){ int xv = S[v - 1] == '1' ? 1 : -1; int xp = S[p - 1] == '1' ? 1 : -1; add(-xv, 0, in[v], diff, diff_group, incr, count, B); add(xp, in[v], out[v], diff, diff_group, incr, count, B); add(-xv, out[v], diff.size(), diff, diff_group, incr, count, B); } } } } int n = std::count(S.begin(), S.end(), '1'); res -= n; res = res / 2 + n; std::cout << res << std::endl; }