結果

問題 No.2337 Equidistant
ユーザー miscalcmiscalc
提出日時 2023-06-02 23:30:10
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
WA  
実行時間 -
コード長 5,642 bytes
コンパイル時間 4,644 ms
コンパイル使用メモリ 272,904 KB
実行使用メモリ 36,528 KB
最終ジャッジ日時 2023-08-28 06:07:38
合計ジャッジ時間 21,360 ms
ジャッジサーバーID
(参考情報)
judge12 / judge14
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 1 ms 4,380 KB
testcase_01 WA -
testcase_02 WA -
testcase_03 WA -
testcase_04 WA -
testcase_05 WA -
testcase_06 WA -
testcase_07 WA -
testcase_08 WA -
testcase_09 WA -
testcase_10 WA -
testcase_11 WA -
testcase_12 WA -
testcase_13 WA -
testcase_14 WA -
testcase_15 WA -
testcase_16 WA -
testcase_17 WA -
testcase_18 WA -
testcase_19 WA -
testcase_20 WA -
testcase_21 AC 858 ms 34,852 KB
testcase_22 AC 676 ms 35,444 KB
testcase_23 WA -
testcase_24 AC 946 ms 34,796 KB
testcase_25 WA -
testcase_26 AC 1,021 ms 34,792 KB
testcase_27 WA -
testcase_28 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ld = long double;
using pll = pair<ll, ll>;
using tlll = tuple<ll, ll, ll>;
constexpr ll INF = 1LL << 60;
/// Replaces a with b when a > b. Returns true iff a was updated.
template<class T> bool chmin(T& a, T b)
{
  if (!(a > b))
    return false;
  a = b;
  return true;
}
/// Replaces a with b when a < b. Returns true iff a was updated.
template<class T> bool chmax(T& a, T b)
{
  if (!(a < b))
    return false;
  a = b;
  return true;
}
/// A mod M normalized into [0, M) for M > 0 (C++ % keeps the sign of A).
ll safemod(ll A, ll M)
{
  const ll r = A % M;
  return r < 0 ? r + M : r;
}
/// floor(A / B) for any signs of A and B (B != 0; assumes -A / -B do not overflow).
ll divfloor(ll A, ll B)
{
  if (B < 0)
  {
    A = -A;
    B = -B;
  }
  const ll rem = safemod(A, B);
  return (A - rem) / B;  // exact division: A - rem is a multiple of B
}
/// ceil(A / B) for any signs of A and B (B != 0; assumes A + B - 1 does not overflow).
ll divceil(ll A, ll B)
{
  if (B < 0)
  {
    A = -A;
    B = -B;
  }
  return divfloor(A + B - 1, B);
}
/// A^B by repeated multiplication; exact only while the result fits in ll.
/// A^0 == 1 for every A (including 0^0). For B < 0 the result is 1 unless A is -1, 0 or 1.
ll pow_ll(ll A, ll B)
{
  if (B == 0)
    return 1;  // convention 0^0 == 1, matching std::pow
  if (A == 0 || A == 1)
    return A;
  if (A == -1)
    return B & 1 ? -1 : 1;
  ll res = 1;
  for (ll i = 0; i < B; i++)
    res *= A;
  return res;
}
/// min(A * B, M) without overflow; assumes A, B >= 0.
ll mul_limited(ll A, ll B, ll M = INF) { return B == 0 ? 0 : A > M / B ? M : A * B; }
/// min(A^B, M) without overflow; assumes A >= 0, B >= 0 and M >= 1.
/// A^0 == 1 for every A (including 0^0).
ll pow_limited(ll A, ll B, ll M = INF)
{
  if (B == 0)
    return 1;
  if (A == 0 || A == 1)
    return A;
  ll res = 1;
  for (ll i = 0; i < B; i++)
  {
    if (res > M / A)  // res * A would exceed M
      return M;
    res *= A;
  }
  return res;
}
/// floor(log_A(B)): the largest e with A^e <= B (0 when B < A). Requires A >= 2.
ll logfloor(ll A, ll B)
{
  assert(A >= 2);
  ll res = 0;
  // tmp <= B / A guarantees tmp * A <= B, so the multiplication never overflows.
  for (ll tmp = 1; tmp <= B / A; tmp *= A)
    res++;
  return res;
}
/// ceil(log_A(B)): the smallest e with A^e >= B (0 when B <= 1). Requires A >= 2.
ll logceil(ll A, ll B)
{
  assert(A >= 2);
  ll res = 0;
  ll tmp = 1;
  while (tmp < B)
  {
    res++;
    // If tmp * A would overflow, it certainly reaches B (B <= LLONG_MAX): this
    // is the final step, so stop before performing the overflowing multiply.
    if (tmp > numeric_limits<ll>::max() / A)
      break;
    tmp *= A;
  }
  return res;
}
/// Sum of the n-term arithmetic progression a, a+d, ..., a+(n-1)d.
ll arisum_ll(ll a, ll d, ll n)
{
  // n*(n-1)/2 computed by halving whichever factor is even, so the
  // un-halved product n*(n-1) is never formed.
  const ll tri = (n & 1) ? ((n - 1) >> 1) * n : (n >> 1) * (n - 1);
  return n * a + tri * d;
}
/// Sum of an n-term arithmetic progression with first term a and last term l: n*(a+l)/2.
ll arisum2_ll(ll a, ll l, ll n)
{
  const ll ends = a + l;
  if (n & 1)
    return (ends >> 1) * n;  // odd n: a + l is twice the middle term, hence even
  return (n >> 1) * ends;
}
/// Sum of a, a+d, ..., l (d != 0; l - a must be a multiple of d).
ll arisum3_ll(ll a, ll l, ll d)
{
  const ll span = l - a;
  assert(span % d == 0);
  const ll terms = span / d + 1;
  return arisum2_ll(a, l, terms);
}
/// Drops consecutive duplicate elements of V in place (sort first to drop all duplicates).
template<class T> void unique(std::vector<T> &V)
{
  const auto new_end = std::unique(V.begin(), V.end());
  V.erase(new_end, V.end());
}
/// Sorts V ascending and keeps exactly one copy of each distinct value.
template<class T> void sortunique(std::vector<T> &V)
{
  std::sort(V.begin(), V.end());
  unique(V);  // once sorted, removing adjacent duplicates removes all duplicates
}
// Prints A followed by a newline and terminates the program with exit status 0.
#define FINALANS(A) do {std::cout << (A) << '\n'; std::exit(0);} while (false)
/// Prints the elements of V on one line separated by single spaces, then a newline.
template<class T> void printvec(const std::vector<T> &V)
{
  const char *sep = "";
  for (const auto &v : V)
  {
    std::cout << sep << v;
    sep = " ";
  }
  std::cout << '\n';
}
/// Prints each element of V on its own line (by const reference: no per-element copy).
template<class T> void printvect(const std::vector<T> &V) {for (const auto &v : V) std::cout << v << '\n';}
/// Prints each row of V as one space-separated line.
template<class T> void printvec2(const std::vector<std::vector<T>> &V) {for (const auto &row : V) printvec(row);}
//*
#include <atcoder/all>
using namespace atcoder;
using mint = modint998244353;
//using mint = modint1000000007;
//using mint = modint;
//*/

/// Directed weighted edge `from -> to`.
template <typename Cost>
struct Edge
{
  int from;   // tail vertex
  int to;     // head vertex
  Cost cost;  // edge weight (defaults to 1)
  Edge(int s, int t, Cost c = 1) : from(s), to(t), cost(c) {}
  // Implicit conversion yields the head vertex, so an Edge can be used
  // directly as a neighbour index.
  operator int() const { return this->to; }
};
/// Adjacency-list graph on vertices 0 .. n-1; (*this)[v] lists the edges leaving v.
template <typename Cost>
struct Graph : vector<vector<Edge<Cost>>>
{
  /// Creates n vertices and no edges.
  Graph(int n) : vector<vector<Edge<Cost>>>(n) {}

  /// Appends the directed edge s -> t with weight c to s's adjacency list.
  void add_edge(int s, int t, Cost c = 1)
  {
    auto &out = (*this)[s];
    out.emplace_back(s, t, c);
  }

  /// Adds the undirected edge {s, t}: first s -> t, then t -> s.
  void add_edge2(int s, int t, Cost c = 1)
  {
    add_edge(s, t, c);
    add_edge(t, s, c);
  }
};

/// Tree on vertices 0 .. n-1 answering ancestor / LCA / distance queries with
/// binary lifting. Add the edges (e.g. with add_edge2), call run(root) once,
/// then issue queries.
template<typename Cost>
struct LowestCommonAncestor : Graph<Cost>
{
  vector<vector<int>> par;  // par[i][v]: the 2^i-th ancestor of v, or -1 if none
  vector<int> dep;          // dep[v]: number of edges from the root to v
  vector<Cost> dists;       // dists[v]: total edge cost from the root to v

  LowestCommonAncestor(int n) : Graph<Cost>::Graph(n) {}

  /// Roots the tree at `root` and builds the depth, distance and doubling tables.
  void run(const int root = 0)
  {
    const int n = (int)(*this).size();
    // Number of doubling levels: the smallest L >= 1 with 2^L > n, i.e.
    // floor(log2(n)) + 1, computed in integers rather than with floating log2.
    // Every depth is < n < 2^L, so L levels reach any ancestor.
    int levels = 1;
    while ((1LL << levels) <= n)
      levels++;
    par.assign(levels, vector<int>(n));
    dep.assign(n, 0);
    dists.assign(n, Cost());
    par[0][root] = -1, dep[root] = 0, dists[root] = 0;
    dfs(root, -1);
    doubling();
  }

  /// Fills par[0], dep and dists for every vertex reachable from v, where pv
  /// is v's parent (-1 for the root). Despite its name this is an iterative
  /// BFS, so path-like trees cannot overflow the call stack.
  void dfs(int v, int pv)
  {
    queue<pair<int, int>> que;  // (vertex, its parent)
    que.emplace(v, pv);
    while (!que.empty())
    {
      const auto [cur, parent] = que.front();
      que.pop();
      for (const auto &e : (*this)[cur])
      {
        const int nxt = e.to;
        if (nxt == parent)
          continue;
        par[0][nxt] = cur;
        dep[nxt] = dep[cur] + 1;
        dists[nxt] = dists[cur] + e.cost;
        que.emplace(nxt, cur);
      }
    }
  }

  /// Builds par[i] from par[i - 1]: the 2^i-th ancestor is the 2^(i-1)-th
  /// ancestor of the 2^(i-1)-th ancestor; -1 propagates.
  void doubling()
  {
    for (int i = 1; i < (int)par.size(); i++)
    {
      for (int v = 0; v < (int)(*this).size(); v++)
      {
        if (par[i - 1][v] == -1)
          par[i][v] = -1;
        else
          par[i][v] = par[i - 1][par[i - 1][v]];
      }
    }
  }

  /// Returns the ancestor of v exactly k edges above it, or -1 if v < 0,
  /// k < 0 or k > dep[v].
  int query_ancestor(int v, int k)
  {
    // The range check matters: bits of k at or above par.size() are never
    // examined by the loop, so an unchecked large k would return a wrong vertex.
    if (v < 0 || k < 0 || k > dep[v])
      return -1;
    for (int i = (int)par.size() - 1; i >= 0; i--)
    {
      if (k >> i & 1)
        v = par[i][v];
    }
    return v;
  }

  /// Lowest common ancestor of u and v.
  int query_lca(int u, int v)
  {
    if (dep[u] < dep[v])
      swap(u, v);

    u = query_ancestor(u, dep[u] - dep[v]);  // lift u to v's depth
    if (u == v)
      return u;
    // Lift both while their ancestors differ; they end just below the LCA.
    for (int i = (int)par.size() - 1; i >= 0; i--)
    {
      if (par[i][u] != par[i][v])
        u = par[i][u], v = par[i][v];
    }
    return par[0][u];
  }

  /// Number of edges on the u-v path.
  int query_dist1(int u, int v)
  {
    return dep[u] + dep[v] - 2 * dep[query_lca(u, v)];
  }
  /// Total edge cost of the u-v path.
  Cost query_dist(int u, int v)
  {
    return dists[u] + dists[v] - 2 * dists[query_lca(u, v)];
  }

  /// Vertex reached after k edges walking from s toward t, or -1 if k < 0 or
  /// k exceeds the number of edges on the s-t path.
  int query_jump(int s, int t, int k)
  {
    if (k < 0)
      return -1;
    int u = query_lca(s, t);
    if (k <= dep[s] - dep[u])
      return query_ancestor(s, k);
    int k2 = dep[s] + dep[t] - 2 * dep[u] - k;  // remaining steps, counted up from t
    return k2 < 0 ? -1 : query_ancestor(t, k2);
  }
};

int main()
{
  ll N, Q;
  cin >> N >> Q;
  LowestCommonAncestor<ll> G(N);
  for (ll i = 0; i < N - 1; i++)
  {
    ll a, b;
    cin >> a >> b;
    a--, b--;
    G.add_edge2(a, b);
  }
  G.run();
  while (Q--)
  {
    ll s, t;
    cin >> s >> t;
    s--, t--;
    ll d = G.query_dist(s, t);
    if (d % 2 != 0)
      cout << 0 << '\n';
    else
    {
      ll u = G.query_jump(s, t, d / 2);
      ll ans = G.at(u).size() - 1;
      cout << ans << '\n';
    }
  }
}
0