結果
| 問題 | No.2892 Lime and Karin |
|---|---|
| コンテスト | |
| ユーザー | tottoripaper |
| 提出日時 | 2024-09-17 00:36:31 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 4,098 bytes |
| コンパイル時間 | 2,649 ms |
| コンパイル使用メモリ | 206,692 KB |
| 最終ジャッジ日時 | 2025-02-24 09:03:42 |
| ジャッジサーバーID (参考情報) | judge4 / judge5 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 51 TLE * 1 |
ソースコード
#include <bits/stdc++.h>
using ll = std::int64_t;
// Pre-order walk of the tree G starting at v (p is v's parent, -1 for the
// root of the walk).
//
// Outputs, all indexed as shown:
//   in[v]       - pre-order index assigned to v (taken from `counter`).
//   out[v]      - value of `counter` once v and all its descendants have been
//                 numbered, so [in[v], out[v]) is the index range of v's subtree.
//   diff[in[v]] - sum of the per-vertex weights on the path from the walk's
//                 root down to v, where a vertex u weighs +1 if S[u - 1] is
//                 '1' and -1 otherwise (vertex ids are 1-based into S).
//
// `counter` is advanced by one for every vertex visited.
void dfs(
    int v,
    int p,
    int &counter,
    std::vector<int> &in,
    std::vector<int> &out,
    std::vector<int> &diff,
    const std::vector<std::vector<int>> &G,
    const std::string &S
){
    const int order = counter++;
    in[v] = order;

    // Own weight, stacked on top of the parent's accumulated total (if any).
    int total = (S[v - 1] == '1') ? 1 : -1;
    if(p != -1){
        total += diff[in[p]];
    }
    diff[order] = total;

    for(const int w : G[v]){
        if(w == p){
            continue;  // do not walk back up the edge we arrived by
        }
        dfs(w, v, counter, in, out, diff, G, S);
    }

    out[v] = counter;
}
// Counts the indices i in [l, r) whose effective value
// diff[i] + incr[i / B] is at least x.
//
// The array is cut into blocks of B consecutive indices:
//   diff       - per-index values (without the block-wide offset),
//   diff_group - the same values, each block sorted ascending in place,
//   incr       - one pending offset per block, applied to every index in it.
//
// Indices in a block only partially covered by [l, r) are tested one by one;
// blocks fully inside the range are answered with a binary search on
// diff_group.
ll search(
    int x, int l, int r,
    const std::vector<int> &diff,
    const std::vector<int> &diff_group,
    const std::vector<int> &incr,
    const int B
){
    ll hits = 0;

    // Ragged left edge: advance l to the next block boundary.
    for(; l < r && l % B > 0; ++l){
        hits += (diff[l] + incr[l / B] >= x) ? 1 : 0;
    }

    // Ragged right edge: pull r back to a block boundary.
    for(; l < r && r % B > 0; --r){
        const int i = r - 1;
        hits += (diff[i] + incr[i / B] >= x) ? 1 : 0;
    }

    // Whole blocks: everything from the lower bound to the block end qualifies.
    const int total = static_cast<int>(diff_group.size());
    while(l < r){
        const int stop = std::min(l + B, total);
        const auto block_first = diff_group.begin() + l;
        const auto block_last = diff_group.begin() + stop;
        // Compare against x shifted by the block's pending offset.
        hits += block_last - std::lower_bound(block_first, block_last, x - incr[l / B]);
        l = stop;
    }

    return hits;
}
// Adds x to the effective value of every index in [l, r) of the
// block-decomposed array (blocks of B consecutive indices).
//
//   diff       - per-index values; updated for indices in partially covered blocks,
//   diff_group - per-block sorted copy of diff; kept sorted by the point update,
//   incr       - per-block pending offset; bumped for fully covered blocks.
//
// The point update only keeps diff_group sorted when x is -1 or +1 (the only
// values the callers in this file pass): for -1 it decrements the FIRST
// occurrence of the old value in the block, for any other x it modifies the
// LAST occurrence, which is order-preserving for +1. It relies on diff[i]
// being present in its block of diff_group.
void add(
    int x, int l, int r,
    std::vector<int> &diff,
    std::vector<int> &diff_group,
    std::vector<int> &incr,
    const int B
){
    const int total = static_cast<int>(diff_group.size());

    const auto bump = [&](int i){
        const int lo = i / B * B;
        const int hi = std::min(lo + B, total);
        const auto block_first = diff_group.begin() + lo;
        const auto block_last = diff_group.begin() + hi;
        // Pick the copy of diff[i] whose change cannot break the ordering.
        const auto slot = (x == -1)
            ? std::lower_bound(block_first, block_last, diff[i])
            : std::upper_bound(block_first, block_last, diff[i]) - 1;
        *slot += x;
        diff[i] += x;
    };

    // Partially covered block on the left.
    for(; l < r && l % B > 0; ++l){
        bump(l);
    }

    // Partially covered block on the right.
    for(; l < r && r % B > 0; --r){
        bump(r - 1);
    }

    // Remaining range is a whole number of blocks: defer via the block offset.
    for(int b = l / B, b_end = r / B; b < b_end; ++b){
        incr[b] += x;
    }
}
// Walks the tree from v (parent p, -1 at the top) and returns the sum, over v
// and every vertex reached below it, of search(1, 0, N): the number of
// indices whose effective value is >= 1 at the moment that vertex is visited.
//
// Before descending into a child w the stored values are shifted so that they
// describe path sums measured from w instead of from v:
//   - indices outside w's subtree range [in[w], out[w]) gain w's weight,
//   - indices inside it lose v's weight.
// The exact inverse shifts are applied after the recursive call returns, so
// diff / diff_group / incr are back in their entry state when dfs2 exits.
//
// A vertex u weighs +1 if S[u - 1] is '1' and -1 otherwise.
ll dfs2(
    int v,
    int p,
    const std::vector<int> &in,
    const std::vector<int> &out,
    std::vector<int> &diff,
    std::vector<int> &diff_group,
    std::vector<int> &incr,
    const std::vector<std::vector<int>> &G,
    const std::string &S,
    const int B
){
    const int n = static_cast<int>(diff.size());
    const int xv = (S[v - 1] == '1') ? 1 : -1;

    // Contribution of v itself with the current values.
    ll total = search(1, 0, n, diff, diff_group, incr, B);

    for(const int w : G[v]){
        if(w == p){
            continue;
        }
        const int xw = (S[w - 1] == '1') ? 1 : -1;
        const int sub_begin = in[w];
        const int sub_end = out[w];

        // Shift the reference vertex from v to w.
        add(xw, 0, sub_begin, diff, diff_group, incr, B);
        add(-xv, sub_begin, sub_end, diff, diff_group, incr, B);
        add(xw, sub_end, n, diff, diff_group, incr, B);

        total += dfs2(w, v, in, out, diff, diff_group, incr, G, S, B);

        // Undo the shift, in the same range order.
        add(-xw, 0, sub_begin, diff, diff_group, incr, B);
        add(xv, sub_begin, sub_end, diff, diff_group, incr, B);
        add(-xw, sub_end, n, diff, diff_group, incr, B);
    }

    return total;
}
// Reads a tree with N vertices (N - 1 edges "u v", 1-based) followed by a
// string S of N characters, and prints the number of vertex paths (single
// vertices included) on which the '1' vertices outnumber the others, i.e. whose
// sum of +1 ('1') / -1 (anything else) weights is at least 1.
//
// Malformed input is rejected up front instead of running into a division by
// zero (block size 0) or out-of-range indexing:
//   - unreadable or non-positive N  -> prints 0 and exits successfully,
//   - S shorter than N characters   -> exits with status 1.
int main(){
    std::cin.tie(nullptr);
    std::ios::sync_with_stdio(false);

    int N;
    if(!(std::cin >> N) || N < 1){
        // No vertices means no paths; also avoids B == 0 below.
        std::cout << 0 << '\n';
        return 0;
    }

    std::vector<std::vector<int>> G(N + 1);
    for(int i=0;i<N-1;i++){
        int u, v;
        std::cin >> u >> v;
        G[u].emplace_back(v);
        G[v].emplace_back(u);
    }

    std::string S;
    std::cin >> S;
    if(S.size() < static_cast<std::size_t>(N)){
        // Every vertex id is used to index S; a short string would be read out of range.
        return 1;
    }

    // Pre-order numbering and root-to-vertex weight sums, rooted at vertex 1.
    std::vector<int> in(N + 1), out(N + 1), diff(N, 0);
    {
        int counter = 0;
        dfs(1, -1, counter, in, out, diff, G, S);
    }

    // Square-root decomposition: blocks of B indices, each with a sorted copy
    // and a pending offset. B is clamped so it can never be zero.
    const int B = std::max(1, static_cast<int>(std::sqrt(N)));
    std::vector<int> diff_group(diff), incr((N + B - 1) / B, 0);
    for(int i=0;i<N;i+=B){
        std::sort(diff_group.begin() + i, diff_group.begin() + std::min(i + B, N));
    }

    // dfs2 counts ordered (start, end) pairs with path sum >= 1, start == end
    // included. Exactly the '1' vertices qualify as single-vertex paths, so
    // remove them, halve the remaining ordered pairs, then add them back.
    ll res = dfs2(1, -1, in, out, diff, diff_group, incr, G, S, B);
    const ll n = std::count(S.begin(), S.end(), '1');
    res -= n;
    res = res / 2 + n;

    // '\n' instead of std::endl: the stream is flushed at exit anyway.
    std::cout << res << '\n';
}
tottoripaper