結果

問題 No.2892 Lime and Karin
ユーザー FUN_Morikubo
提出日時 2024-09-13 23:18:39
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 225 ms / 8,000 ms
コード長 3,418 bytes
コンパイル時間 8,060 ms
コンパイル使用メモリ 336,608 KB
実行使用メモリ 28,060 KB
最終ジャッジ日時 2024-09-13 23:18:55
合計ジャッジ時間 15,381 ms
ジャッジサーバーID
(参考情報)
judge4 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 52
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;

// Short type aliases used throughout this file.
using ll = long long;                              // 64-bit signed integer
template <typename T> using vec = std::vector<T>;  // 1-D dynamic array
template <typename T> using vv = vec<vec<T>>;      // 2-D: vector of vectors
template <typename T> using vvv = vec<vv<T>>;      // 3-D: vector of 2-D vectors
using vvl = vv<ll>;
using vl = vec<ll>;
// rep(i, n): loop with an `ll` index i = 0, 1, ..., up to (but excluding) ll(n).
// Note: the bound expression `n` is re-evaluated on every iteration.
#define rep(i, n) for (ll i = 0; i < ll(n); i++)
// eb: shorthand for the container member `emplace_back`.
#define eb emplace_back

#include <atcoder/all>
using namespace atcoder;

// Segment-tree combine operation: the sum of two counts.
ll op(ll lhs, ll rhs) {
  const ll total = lhs + rhs;
  return total;
}
// Identity element for `op` (the additive identity, zero).
ll e() { return ll{0}; }
void solve() {
  ll n;
  cin >> n;
  vvl g(n);
  rep(i, n - 1) {
    ll u, v;
    cin >> u >> v;
    u--, v--;
    g[u].eb(v);
    g[v].eb(u);
  }
  string s;
  cin >> s;

  vl bfs;
  bfs.reserve(n);
  vl par(n, -1), sz(n, 1), depth(n, 0);
  vvl child(n);
  {
    queue<ll> q({0});
    while (q.size() > 0) {
      ll now = q.front();
      q.pop();
      bfs.eb(now);
      for (auto to : g[now]) {
        if (par[now] == to) continue;
        par[to] = now;
        q.push(to);
      }
    }
  }
  for (auto now : bfs | views::reverse) {
    for (auto to : g[now]) {
      if (par[now] == to) continue;
      sz[now] += sz[to];
      depth[now] = max(depth[now], depth[to] + 1);
      child[now].eb(to);
      if (sz[child[now][0]] < sz[child[now].back()])
        swap(child[now][0], child[now].back());
    }
  }

  // rep(now, n) cout << sz[now] << " ";
  // cout << "\n";

  using SEG = segtree<ll, op, e>;
  ll ans = 0;
  auto dfs = [&](auto f, ll now) -> tuple<ll, SEG, ll> {
    ll cur = now;
    ll seg_sz = 2 * sz[now] + 3;
    SEG seg(seg_sz);
    ll diff = sz[now] + 1;
    while (child[cur].size() > 0) cur = child[cur][0];

    while (true) {
      seg.set(diff, seg.get(diff) + 1);
      if (s[cur] == '0')
        diff--;
      else
        diff++;
      // if (seg.prod(0, diff) > 0)
      //   cout << cur << ", + " << seg.prod(0, diff) << "\n";
      ans += seg.prod(0, diff);

      for (ll i = 1; i < child[cur].size(); i++) {
        ll to = child[cur][i];
        auto [c_diff, c_seg, c_seg_sz] = f(f, to);
        for (ll j = 0; j < c_seg_sz; j++) {
          ll c_idx = j - c_diff;
          ll idx = diff - c_idx;
          if (idx < 0 || seg_sz < idx) continue;
          ll tmp = seg.prod(0, idx) * c_seg.get(j);
          ans += tmp;
          // if (tmp > 0) cout << cur << ", " << to << ", + " << tmp << "\n";
        }
        for (ll j = 0; j < c_seg_sz; j++) {
          ll c_idx = j - c_diff;
          ll idx = c_idx + diff;
          if (s[cur] == '0') idx++;
          else idx--;
          if (idx < 0 || seg_sz <= idx) continue;
          ll tmp = c_seg.get(j);
          if (tmp == 0) continue;
          tmp += seg.get(idx);
          seg.set(idx, tmp);
        }
        // cout << "cur, diff: " << cur << " " << diff << "\n";
        // rep(i, diff) cout << seg.get(i) << " |"[i + 1 == diff];
        // cout << seg.get(diff) << "|";
        // for (ll i = diff + 1; i < seg_sz; i++) cout << seg.get(i) << " ";
        // cout << "\n";
      }
      // if (child[cur].size() <= 1) {
      //   cout << "cur, diff: " << cur << " " << diff << "\n";
      //   rep(i, diff) cout << seg.get(i) << " |"[i + 1 == diff];
      //   cout << seg.get(diff) << "|";
      //   for (ll i = diff + 1; i < seg_sz; i++) cout << seg.get(i) << " ";
      //   cout << "\n";
      // }
      if (cur == now) return {diff, seg, seg_sz};
      cur = par[cur];
    }
  };
  dfs(dfs, 0);
  cout << ans << "\n";
}

// Program entry point: runs solve() t times (a single test case by default).
int main() {
  // solve() reads and writes only through cin/cout. Untying them from C stdio
  // speeds up the bulk edge input.
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  ll t = 1;
  // cin >> t;  // enable for multi-test-case input
  rep(i, t) solve();
}
0