結果

問題 No.3193 Submit Your Solution
ユーザー Moss_Local
提出日時 2025-06-27 22:18:00
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
TLE  
実行時間 -
コード長 3,865 bytes
コンパイル時間 4,721 ms
コンパイル使用メモリ 425,224 KB
実行使用メモリ 42,720 KB
最終ジャッジ日時 2025-06-27 22:18:18
合計ジャッジ時間 17,301 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other TLE * 1 -- * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
#include <x86intrin.h>
#pragma GCC optimize("O3")
#pragma GCC target("avx2")

using namespace std;

using ll = long long;
using u32 = unsigned int;
using u64 = unsigned long long;
using i128 = __int128;
using u128 = unsigned __int128;
using f128 = __float128;

int n;
vector<vector<int>> g1;
vector<vector<pair<int, int>>> g2;
vector<int> tin, tout, euler;
vector<int> dist1n, dist2n;
int timer = 0;

void dt(int u, int p) {
  tin[u] = timer;
  euler[timer++] = u;
  for (int v : g1[u]) {
    if (v == p) continue;
    dt(v, u);
  }
  tout[u] = timer;
}

void dfs1(int u, int p, int d) {
  dist1n[u] = d;
  for (int v : g1[u]) {
    if (v == p) continue;
    dfs1(v, u, d + 1);
  }
}

void dfs2(int u, int p, int d) {
  dist2n[u] = d;
  for (auto &ed : g2[u]) {
    int v = ed.first, w = ed.second;
    if (v == p) continue;
    dfs2(v, u, d + w);
  }
}

static constexpr int LANES = 8;

void add_constf(int32_t *buf, int len, int32_t val) {
  __m256i v = _mm256_set1_epi32(val);
  int i = 0;
  for (; i + LANES <= len; i += LANES) {
    __m256i seg = _mm256_loadu_si256((__m256i *)(buf + i));
    seg = _mm256_add_epi32(seg, v);
    _mm256_storeu_si256((__m256i *)(buf + i), seg);
  }
  for (; i < len; ++i) buf[i] += val;
}

void range_add(int32_t *buf, int l, int r, int32_t val) {
  __m256i v = _mm256_set1_epi32(val);
  int i = l;
  for (; i + LANES <= r; i += LANES) {
    __m256i seg = _mm256_loadu_si256((__m256i *)(buf + i));
    seg = _mm256_add_epi32(seg, v);
    _mm256_storeu_si256((__m256i *)(buf + i), seg);
  }
  for (; i < r; ++i) buf[i] += val;
}

u128 dot_prod(const int32_t *a, const int32_t *b, int len) {
  u128 sum = 0;
  int i = 0;
  alignas(32) int32_t tmp[LANES];
  for (; i + LANES <= len; i += LANES) {
    __m256i va = _mm256_loadu_si256((__m256i *)(a + i));
    __m256i vb = _mm256_loadu_si256((__m256i *)(b + i));
    __m256i prod = _mm256_mullo_epi32(va, vb);
    _mm256_store_si256((__m256i *)tmp, prod);
    for (int j = 0; j < LANES; ++j) {
      sum += (int64_t)tmp[j];
    }
  }
  for (; i < len; ++i) {
    sum += (int64_t)a[i] * b[i];
  }
  return sum;
}

// // no simd
// void add_constf(int32_t *buf, int len, int32_t val) {
//   for (int i = 0; i < len; ++i) {
//     buf[i] += val;
//   }
// }

// void range_add(int32_t *buf, int l, int r, int32_t val) {
//   for (int i = l; i < r; ++i) {
//     buf[i] += val;
//   }
// }
// u128 dot_prod(const int32_t *a, const int32_t *b, int len) {
//   u128 sum = 0;
//   for (int i = 0; i < len; ++i) {
//     sum += (u128)a[i] * b[i];
//   }
//   return sum;
// }

void reroot(int u, int p, vector<int32_t> &dist1, vector<int32_t> &dist2,
            u128 &ans) {
  dfs2(u, -1, 0);
  for (int i = 0; i < n; ++i) {
    dist2[i] = dist2n[euler[i]];
  }
  ans += dot_prod(dist1.data(), dist2.data(), n);

  for (int v : g1[u]) {
    if (v == p) continue;
    add_constf(dist1.data(), n, +1);
    range_add(dist1.data(), tin[v], tout[v], -2);

    reroot(v, u, dist1, dist2, ans);

    range_add(dist1.data(), tin[v], tout[v], +2);
    add_constf(dist1.data(), n, -1);
  }
}

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  cin >> n;
  g1.assign(n, {});
  g2.assign(n, {});
  int a, b;
  for (int i = 0; i < n - 1; ++i) {
    cin >> a >> b;
    --a;
    --b;
    g1[a].push_back(b);
    g1[b].push_back(a);
  }
  for (int i = 0; i < n - 1; ++i) {
    cin >> a >> b;
    --a;
    --b;
    g2[a].emplace_back(b, 1);
    g2[b].emplace_back(a, 1);
  }

  tin.resize(n);
  tout.resize(n);
  euler.resize(n);
  dist1n.resize(n);
  dist2n.resize(n);

  dt(0, -1);

  dfs1(0, -1, 0);

  vector<int32_t> dist1(n), dist2(n);
  for (int i = 0; i < n; ++i) {
    dist1[i] = dist1n[euler[i]];
  }

  u128 ans = 0;
  reroot(0, -1, dist1, dist2, ans);

  uint64_t result = (uint64_t)(ans & (((u128)1 << 64) - 1));
  cout << result << "\n";
  return 0;
}
0