結果
問題 | No.2892 Lime and Karin |
ユーザー | |
提出日時 | 2024-09-13 23:16:24 |
言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
結果 | RE |
実行時間 | - |
コード長 | 3,315 bytes |
コンパイル時間 | 7,985 ms |
コンパイル使用メモリ | 337,172 KB |
実行使用メモリ | 28,172 KB |
最終ジャッジ日時 | 2024-09-13 23:16:52 |
合計ジャッジ時間 | 20,491 ms |
ジャッジサーバーID (参考情報) | judge2 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 19 RE * 33 |
ソースコード
// Solution for yukicoder No.2892 "Lime and Karin".
//
// Input : n, then n-1 undirected edges (1-indexed vertices), then a label
//         string s with one character per vertex.
// Output: the number of vertex paths u..v (u == v allowed, each unordered
//         pair counted once) whose balance is strictly positive, where a
//         vertex labelled '0' contributes -1 and any other label +1.
//
// Approach: every heavy path is walked bottom-up while a counter of path
// balances is maintained; light subtrees are solved recursively and merged
// into the counter of the heavy path they hang off.  Each vertex belongs to
// O(log n) light subtrees, so the total work is O(n log^2 n).
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using ll = long long;
using vl = std::vector<ll>;
using vvl = std::vector<vl>;

// A multiset of integer keys taken from [0, size()).
// Multiplicities are kept in a plain array (O(1) point reads) and mirrored in
// a Fenwick tree so that "how many keys are < bound" costs O(log n).
class KeyCounter {
 public:
  explicit KeyCounter(ll n)
      : n_(n),
        point_(static_cast<std::size_t>(n), 0),
        tree_(static_cast<std::size_t>(n) + 1, 0) {}

  ll size() const { return n_; }

  // Adds `delta` copies of `key`.  Requires 0 <= key < size().
  void add(ll key, ll delta) {
    assert(0 <= key && key < n_);
    point_[key] += delta;
    for (ll i = key + 1; 0 < i && i <= n_; i += i & -i) tree_[i] += delta;
  }

  // Multiplicity of `key`.  Requires 0 <= key < size().
  ll get(ll key) const {
    assert(0 <= key && key < n_);
    return point_[key];
  }

  // Number of stored keys strictly less than `bound`.  Defined for every
  // bound: bounds <= 0 give 0, bounds >= size() give the total count.
  ll count_less(ll bound) const {
    ll total = 0;
    for (ll i = std::min(bound, n_); i > 0; i -= i & -i) total += tree_[i];
    return total;
  }

 private:
  ll n_;
  vl point_;  // point_[k] = multiplicity of key k
  vl tree_;   // 1-indexed Fenwick tree over point_
};

// Counts paths of the tree `g` (adjacency lists over vertices 0..n-1) whose
// balance is > 0, where s[v] == '0' contributes -1 and any other character +1.
// Single vertices count as paths; each unordered pair {u, v} is counted once.
// Assumes `g` is a tree and s.size() >= g.size().
ll count_positive_paths(const vvl& g, const std::string& s) {
  const ll n = static_cast<ll>(g.size());
  if (n == 0) return 0;

  // BFS from vertex 0: parents, plus an order in which parents precede children.
  vl order;
  order.reserve(n);
  vl par(n, -1);
  order.push_back(0);
  for (std::size_t head = 0; head < order.size(); ++head) {
    const ll now = order[head];
    for (const ll to : g[now]) {
      if (to == par[now]) continue;
      par[to] = now;
      order.push_back(to);
    }
  }

  // Subtree sizes and child lists, children handled before their parents.
  // child[v][0] is kept as a largest child (the heavy child).
  vl sz(n, 1);
  vvl child(n);
  for (auto it = order.rbegin(); it != order.rend(); ++it) {
    const ll now = *it;
    for (const ll to : g[now]) {
      if (to == par[now]) continue;
      sz[now] += sz[to];
      child[now].push_back(to);
      if (sz[child[now].front()] < sz[child[now].back()]) {
        std::swap(child[now].front(), child[now].back());
      }
    }
  }

  auto step_of = [&](ll v) -> ll { return s[v] == '0' ? -1 : 1; };

  ll ans = 0;

  // Walks the heavy path whose top vertex is `top`, adding to `ans` every
  // positive path whose highest vertex lies on that heavy path.  Returns
  // {diff, counter}: the counter holds one key per path that starts at `top`
  // and descends into its subtree, a path of balance b being stored at key
  // diff - b.  Moving up one vertex shifts every stored balance by the same
  // amount, which is done by moving `diff` instead of rewriting keys.
  //
  // Every key (and every query bound used below) equals sz[top] plus a sum of
  // +-1 terms over distinct vertices of top's subtree, so it stays within
  // [0, 2 * sz[top]] -- hence the counter size.
  auto walk = [&](auto&& self, ll top) -> std::pair<ll, KeyCounter> {
    KeyCounter seg(2 * sz[top] + 1);
    ll diff = sz[top];

    ll cur = top;
    while (!child[cur].empty()) cur = child[cur].front();  // bottom of the heavy path

    while (true) {
      const ll step = step_of(cur);

      // Insert the single-vertex path {cur}, then extend every stored path up
      // to cur by shifting diff by cur's contribution.
      seg.add(diff, 1);
      diff += step;

      // Paths starting at cur and going down the heavy path (incl. {cur}).
      ans += seg.count_less(diff);

      for (std::size_t i = 1; i < child[cur].size(); ++i) {
        auto [c_diff, c_seg] = self(self, child[cur][i]);

        // Join every path descending into this light child with every path
        // already stored at cur.  Balances (diff - k) + (c_diff - j) > 0 iff
        // k < diff + c_diff - j.  Empty keys MUST be skipped: their bound can
        // fall outside [0, seg.size()], whereas for nonzero keys it provably
        // stays inside.
        for (ll j = 0; j < c_seg.size(); ++j) {
          const ll cnt = c_seg.get(j);
          if (cnt == 0) continue;
          ans += cnt * seg.count_less(diff + c_diff - j);
        }

        // Extend this light child's paths by cur and store them at cur, so
        // later light children and ancestors can pair with them.
        for (ll j = 0; j < c_seg.size(); ++j) {
          const ll cnt = c_seg.get(j);
          if (cnt == 0) continue;
          seg.add(diff + (j - c_diff) - step, cnt);
        }
      }

      if (cur == top) return {diff, std::move(seg)};
      cur = par[cur];
    }
  };

  walk(walk, 0);
  return ans;
}

// Reads one test case from stdin and prints the answer to stdout.
void solve() {
  ll n = 0;
  std::cin >> n;
  vvl g(n);
  for (ll i = 0; i + 1 < n; ++i) {
    ll u = 0;
    ll v = 0;
    std::cin >> u >> v;
    --u;
    --v;
    g[u].push_back(v);
    g[v].push_back(u);
  }
  std::string s;
  std::cin >> s;
  std::cout << count_positive_paths(g, s) << '\n';
}

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  solve();
}