結果

問題 No.2337 Equidistant
ユーザー SnowBeenDiding
提出日時 2023-06-03 00:41:05
言語 C++23 (gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 1,546 ms / 4,000 ms
コード長 15,139 bytes
コンパイル時間 6,458 ms
コンパイル使用メモリ 328,064 KB
実行使用メモリ 67,336 KB
最終ジャッジ日時 2024-06-09 02:50:59
合計ジャッジ時間 29,457 ms
ジャッジサーバーID (参考情報): judge2 / judge5
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 2 ms 6,816 KB
testcase_01 AC 1 ms 6,812 KB
testcase_02 AC 2 ms 6,812 KB
testcase_03 AC 2 ms 6,944 KB
testcase_04 AC 2 ms 6,940 KB
testcase_05 AC 2 ms 6,940 KB
testcase_06 AC 11 ms 6,940 KB
testcase_07 AC 12 ms 6,940 KB
testcase_08 AC 12 ms 6,940 KB
testcase_09 AC 11 ms 6,940 KB
testcase_10 AC 12 ms 6,940 KB
testcase_11 AC 1,232 ms 42,644 KB
testcase_12 AC 1,215 ms 42,716 KB
testcase_13 AC 1,202 ms 42,716 KB
testcase_14 AC 1,222 ms 42,840 KB
testcase_15 AC 1,230 ms 42,704 KB
testcase_16 AC 1,195 ms 42,840 KB
testcase_17 AC 1,205 ms 42,576 KB
testcase_18 AC 1,266 ms 42,692 KB
testcase_19 AC 1,221 ms 42,632 KB
testcase_20 AC 1,228 ms 42,708 KB
testcase_21 AC 1,402 ms 67,336 KB
testcase_22 AC 987 ms 40,776 KB
testcase_23 AC 1,141 ms 41,556 KB
testcase_24 AC 1,546 ms 60,140 KB
testcase_25 AC 1,191 ms 41,548 KB
testcase_26 AC 1,374 ms 59,932 KB
testcase_27 AC 972 ms 41,680 KB
testcase_28 AC 965 ms 41,480 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <atcoder/all>
using namespace atcoder;
// Modular-integer type used throughout (AtCoder Library, modulus 998244353).
using mint = modint998244353;
// The same modulus as a plain integer, for arithmetic on raw long longs.
const long long MOD = 998244353;
// Alternative modulus 1000000007: swap in both lines below together.
// using mint = modint1000000007;
// const long long MOD = 1000000007;
// Runtime modulus: mint::set_mod(MOD) must be called before any use.
// using mint = modint;//mint::set_mod(MOD);

#include <bits/stdc++.h>
// Loop helpers over a long long (ll) index: half-open [a, b), closed [a, b],
// and descending from a down to b inclusive. Bounds are cast to ll.
#define rep(i, a, b) for (ll i = (ll)(a); i < (ll)(b); i++)
#define repeq(i, a, b) for (ll i = (ll)(a); i <= (ll)(b); i++)
#define repreq(i, a, b) for (ll i = (ll)(a); i >= (ll)(b); i--)
// Every later `endl` becomes a plain '\n' (no per-line flush). Interactive
// problems need an explicit flush instead.
#define endl '\n'  // fflush(stdout);
// Print "Yes"/"No" followed by a newline.
#define cYes cout << "Yes" << endl
#define cNo cout << "No" << endl
// Descending sort. Pass an iterator pair, e.g. sortr(ALL(v)), which expands to
// sort((v).begin(), (v).end(), greater<>()); a bare container argument would
// produce sort(v, greater<>()), which std::sort does not accept.
#define sortr(v) sort(v, greater<>())
// Object-like short aliases. They rewrite EVERY later token with these names,
// identifiers included (e.g. a parameter named `mp` becomes `make_pair`).
#define pb push_back
#define pob pop_back
#define mp make_pair
#define mt make_tuple
#define FI first
#define SE second
// Begin/end iterator pair of a container, for algorithm calls.
#define ALL(v) (v).begin(), (v).end()
// Large sentinel values: about 3e18 for long long, about 1e9 for int.
#define INFLL 3000000000000000100LL
#define INF 1000000100
// Pi and 2*pi in long double precision.
#define PI acos(-1.0L)
#define TAU (PI * 2.0L)

using namespace std;

typedef long long ll;
typedef pair<ll, ll> Pll;
typedef tuple<ll, ll, ll> Tlll;
typedef vector<int> Vi;
typedef vector<Vi> VVi;
typedef vector<ll> Vl;
typedef vector<Vl> VVl;
typedef vector<VVl> VVVl;
typedef vector<Tlll> VTlll;
typedef vector<mint> Vm;
typedef vector<Vm> VVm;
typedef vector<string> Vs;
typedef vector<double> Vd;
typedef vector<char> Vc;
typedef vector<bool> Vb;
typedef vector<Pll> VPll;
typedef priority_queue<ll> PQl;
typedef priority_queue<ll, vector<ll>, greater<ll>> PQlr;

/* inout */
// Writes a mint as its stored residue (m.val()).
ostream &operator<<(ostream &os, mint const &m) {
    return os << m.val();
}
// Reads a long long from the stream and converts it into a mint.
istream &operator>>(istream &is, mint &m) {
    long long raw;
    is >> raw;
    m = raw;
    return is;
}
// Writes the elements separated by spaces, with a newline after the last one.
// An empty vector writes nothing.
template <typename T>
ostream &operator<<(ostream &os, const vector<T> &v) {
    const int len = v.size();
    for (int k = 0; k < len; ++k) {
        const char sep = (k + 1 == len) ? '\n' : ' ';
        os << v[k] << sep;
    }
    return os;
}
// Writes a 2-D vector row by row; each row goes through the 1-D vector
// printer, which ends non-empty rows with a newline.
template <typename T>
ostream &operator<<(ostream &os, const vector<vector<T>> &v) {
    for (const auto &row : v) os << row;
    return os;
}
// Writes a pair as "first second" with no trailing newline.
template <typename T, typename S>
ostream &operator<<(ostream &os, pair<T, S> const &p) {
    return os << p.first << ' ' << p.second;
}
// Writes one "key:value" line per entry, in key order.
template <typename T, typename S>
ostream &operator<<(ostream &os, const map<T, S> &m) {
    for (auto it = m.begin(); it != m.end(); ++it) {
        os << it->first << ':' << it->second << '\n';
    }
    return os;
}
// Writes the elements in sorted order separated by spaces, with a newline
// after the last one. An empty set writes nothing.
template <typename T>
ostream &operator<<(ostream &os, const set<T> &st) {
    const int total = st.size();
    int written = 0;
    for (const auto &x : st) {
        ++written;
        os << x << (written == total ? '\n' : ' ');
    }
    return os;
}
// Writes the elements in sorted order (duplicates included) separated by
// spaces, with a newline after the last one. An empty multiset writes nothing.
// Takes a const reference, like the set overload, so const multisets and
// temporaries can be printed too. The previous non-const reference rejected them.
template <typename T>
ostream &operator<<(ostream &os, const multiset<T> &st) {
    const int total = st.size();
    int written = 0;
    for (const auto &x : st) {
        ++written;
        os << x << (written == total ? '\n' : ' ');
    }
    return os;
}
// Writes a queue front to back, separated by spaces, with a newline after the
// last element. The queue is taken by value, so the caller's copy is unchanged.
template <typename T>
ostream &operator<<(ostream &os, queue<T> q) {
    // The separator is chosen after popping: a newline once nothing remains.
    for (; !q.empty(); os << (q.empty() ? '\n' : ' ')) {
        os << q.front();
        q.pop();
    }
    return os;
}
// Writes a stack bottom to top, using the vector printer. The stack is taken
// by value, so the caller's copy is unchanged.
template <typename T>
ostream &operator<<(ostream &os, stack<T> st) {
    vector<T> items;
    items.reserve(st.size());
    for (; !st.empty(); st.pop()) items.push_back(st.top());
    // items were collected top-first; flip them to bottom-first.
    reverse(items.begin(), items.end());
    return os << items;
}
// Writes a priority_queue in pop order (top first), using the vector printer.
// The queue is taken by value, so the caller's copy is unchanged.
template <class T, class Container, class Compare>
ostream &operator<<(ostream &os, priority_queue<T, Container, Compare> pq) {
    vector<T> popped;
    popped.reserve(pq.size());
    for (; !pq.empty(); pq.pop()) popped.push_back(pq.top());
    return os << popped;
}
// Reads exactly v.size() values into an already-sized vector, in index order.
template <typename T>
istream &operator>>(istream &is, vector<T> &v) {
    for (size_t idx = 0; idx < v.size(); ++idx) is >> v[idx];
    return is;
}
// Reads "first second" into a pair.
template <typename T1, typename T2>
istream &operator>>(istream &is, pair<T1, T2> &p) {
    return is >> p.first >> p.second;
}

/* useful */
// Counts the elements of the sorted vector a that are strictly less than x.
template <typename T>
int SMALLER(vector<T> &a, T x) {
    const auto first_not_less = lower_bound(a.begin(), a.end(), x);
    return static_cast<int>(distance(a.begin(), first_not_less));
}
// Counts the elements of the sorted vector a that are less than or equal to x.
template <typename T>
int orSMALLER(vector<T> &a, T x) {
    const auto first_greater = upper_bound(a.begin(), a.end(), x);
    return static_cast<int>(distance(a.begin(), first_greater));
}
// Counts the elements of the sorted vector a that are strictly greater than x.
template <typename T>
int BIGGER(vector<T> &a, T x) {
    const int not_bigger = orSMALLER(a, x);
    return static_cast<int>(a.size()) - not_bigger;
}
// Counts the elements of the sorted vector a that are greater than or equal to x.
template <typename T>
int orBIGGER(vector<T> &a, T x) {
    const int strictly_smaller = SMALLER(a, x);
    return static_cast<int>(a.size()) - strictly_smaller;
}
// Counts the occurrences of x in the sorted vector a.
template <typename T>
int COUNT(vector<T> &a, T x) {
    const auto same = equal_range(a.begin(), a.end(), x);
    return same.second - same.first;
}
// Stores b into a when a < b. Returns true iff a was updated.
template <typename T, typename S>
bool chmax(T &a, S b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}
// Stores b into a when a > b. Returns true iff a was updated.
template <typename T, typename S>
bool chmin(T &a, S b) {
    if (!(a > b)) return false;
    a = b;
    return true;
}
// Removes consecutive duplicate elements in place. Sort first to deduplicate
// the whole container.
template <typename T>
void press(T &v) {
    const auto new_end = unique(v.begin(), v.end());
    v.erase(new_end, v.end());
}
// Coordinate compression. For each b[i], returns its rank among the distinct
// values of b (0 = smallest); equal values get equal ranks.
// The (value, index) pairs live in a heap-allocated vector. The former
// variable-length stack array was a non-standard GCC extension and could
// overflow the stack for large inputs.
template <typename T>
vector<int> zip(vector<T> b) {
    const int a = b.size();
    vector<pair<T, int>> p;
    p.reserve(a);
    for (int i = 0; i < a; i++) p.emplace_back(b[i], i);
    sort(p.begin(), p.end());
    vector<int> l(a);
    int w = 0;
    for (int i = 0; i < a; i++) {
        // A new distinct value starts a new rank.
        if (i && p[i].first != p[i - 1].first) w++;
        l[p[i].second] = w;
    }
    return l;
}
// Prefix sums: S[0] = T() and S[i + 1] = S[i] + v[i], so S has v.size() + 1
// entries.
template <typename T>
vector<T> vis(vector<T> &v) {
    vector<T> S(v.size() + 1);
    for (size_t i = 0; i < v.size(); ++i) S[i + 1] += v[i] + S[i];
    return S;
}

// Ceiling division: the smallest integer q with q >= a / b (requires b != 0).
// The earlier (a + b - 1) / b form was only correct for a >= 0 and b > 0, and
// a + b - 1 could overflow for a near LLONG_MAX. This version handles any signs
// and never overflows (except LLONG_MIN / -1, as with plain division).
long long dem(long long a, long long b) {
    long long q = a / b;  // C++ truncates toward zero
    // Round up when the exact quotient is positive and not an integer.
    if (a % b != 0 && ((a < 0) == (b < 0))) ++q;
    return q;
}
// Fixed-point conversion: returns d * 10^g rounded to the nearest integer
// (std::round, halves away from zero).
ll dtoll(double d, int g) {
    const double scaled = d * pow(10, g);
    return round(scaled);
}
// For non-negative n, returns its lowest d binary digits, most significant
// first, zero-padded to exactly d characters. d <= 0 gives "".
string tobin(ll n, ll d) {
    if (d <= 0) return "";
    string ret(d, '0');
    for (ll pos = d - 1; pos >= 0; --pos) {
        if (n % 2) ret[pos] = '1';
        n /= 2;
    }
    return ret;
}

// Small tolerance (1e-10). Not referenced in the visible code; presumably meant
// for floating-point comparisons.
constexpr double EPS{1e-10};

// Fast-I/O setup: unsync the C++ streams from C stdio, untie cin and cout, and
// make cout print floating-point values in fixed notation with 12 decimals.
void init() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    cout << fixed << setprecision(12);
}

// do {} while (next_permutation(ALL(vec)));

/********************************** START **********************************/

void sol();

// Entry point: set up fast I/O, then run sol() once per test case. The case
// count is fixed at 1; uncomment the read for multi-case input.
int main() {
    init();
    int q = 1;
    // cin >> q;
    for (; q > 0; --q) sol();
    return 0;
}

/********************************** SOLVE **********************************/

/*
 * Tree: rooted-tree helper over vertices 0..n-1.
 *
 * Workflow: connect() each edge, make(root) to compute par/dep/chi/subsize/
 * euler, then query lca/up/mid_point (binary-lifting tables are built lazily
 * by init_lca), or use the heavy-light decomposition after init_hld().
 * T is the cost type of the weighted connect(s, t, c) overload.
 */
template <typename T>
struct Tree {
    int n = 0, root = 0;
    bool edgeCost = false, lca_ok = false;
    // gr: adjacency lists; chi: children (after make); par_exp[k][v]: the 2^k-th
    // ancestor of v, or -1 past the root.
    vector<vector<int>> gr, chi, par_exp;
    // al: visited flags. par (-1 at the root), dep, subsize: set by make().
    // euler: vertex sequence recorded by build_dfs.
    vector<int> al, par, dep, subsize, euler;
    // Weighted adjacency (neighbor, cost), filled by connect(s, t, c).
    vector<vector<pair<int, T>>> gr_edge_cost;
    // Per weighted edge: (s, t, position of t in gr[s], position of s in gr[t],
    // cost). Positions are taken at insertion time; make() later sorts gr.
    vector<tuple<int, int, int, int, T>> edge_info;

    /* Heavy-light decomposition state */
    vector<vector<int>> gr_hld;
    vector<int> stsize, pathtop, in, out;

    // Edgeless tree on n vertices. This is also the default constructor
    // (n = 0). A separate `Tree() {}` would make `Tree<T> t;` ambiguous.
    // edgeCost = true pre-sizes the weighted adjacency lists.
    Tree(int n = 0, bool edgeCost = false) : n(n), edgeCost(edgeCost) {
        gr = vector<vector<int>>(n);
        if (edgeCost) gr_edge_cost.resize(n);
    }
    // Builds the tree from an adjacency list. A BFS from vertex 0 adds edge
    // (u, v) the first time v is reached, so only a spanning tree of vertex 0's
    // component is kept. An empty list yields an empty tree.
    Tree(const vector<vector<int>> &input) {
        n = input.size();
        gr = vector<vector<int>>(n);
        if (n == 0) return;
        vector<bool> seen(n);
        seen[0] = true;
        queue<int> q;
        q.push(0);
        while (!q.empty()) {
            int cur = q.front();
            q.pop();
            for (int e : input[cur]) {
                if (seen[e]) continue;
                connect(cur, e);
                q.push(e);
                seen[e] = true;
            }
        }
    }
    // Adds the undirected edge s - t.
    void connect(int s, int t) {
        gr[s].push_back(t);
        gr[t].push_back(s);
    }
    // Adds the undirected edge s - t with cost c, also recording it in
    // gr_edge_cost and edge_info. gr_edge_cost is sized on first use, so this
    // works even if the tree was not constructed with edgeCost = true.
    void connect(int s, int t, T c) {
        if (gr_edge_cost.size() < gr.size()) gr_edge_cost.resize(gr.size());
        edge_info.push_back({s, t, static_cast<int>(gr[s].size()),
                             static_cast<int>(gr[t].size()), c});
        gr[s].push_back(t);
        gr[t].push_back(s);
        gr_edge_cost[s].push_back({t, c});
        gr_edge_cost[t].push_back({s, c});
    }
    // Recursive DFS from d used by make(). Marks visited vertices in al, fills
    // par/dep/chi, and sets subsize[d]. euler receives d before each child's
    // subtree and once more after the last child.
    // Recursion depth equals the tree height; a very long path may exhaust a
    // small call stack.
    void build_dfs(int d) {
        int sum_subsize = 0;
        for (auto &p : gr[d]) {
            if (al[p]) continue;
            al[p] = 1;
            par[p] = d;
            dep[p] = dep[d] + 1;
            chi[d].push_back(p);
            euler.push_back(d);
            build_dfs(p);
            sum_subsize += subsize[p];
        }
        subsize[d] = sum_subsize + 1;
        euler.push_back(d);
    }
    // Roots the tree at _root: sorts every adjacency list, then recomputes
    // par/dep/chi/subsize/euler. Ancestor tables built for a previous root are
    // discarded (rebuilt lazily by the next lca/up call), so re-rooting, e.g.
    // via rad(), no longer leaves stale LCA answers.
    void make(int _root) {
        root = _root;
        for (auto &adj : gr) sort(adj.begin(), adj.end());
        al = par = subsize = dep = vector<int>(n, 0);
        chi = vector<vector<int>>(n);
        euler = {};
        par_exp.clear();
        lca_ok = false;
        al[root] = 1;
        par[root] = -1;
        dep[root] = 0;
        build_dfs(root);
    }
    // Finds the two endpoints of a diameter (a longest path) by double sweep:
    // the deepest vertex from 0, then the deepest vertex from that one.
    // Returns (second-sweep endpoint, first-sweep endpoint). Leaves the tree
    // rooted at the first-sweep endpoint. Requires n >= 1.
    pair<int, int> rad() {
        auto deepest = [&]() {
            int max_ind = -1, max_dep = -1;
            for (int i = 0; i < n; i++) {
                if (max_dep < dep[i]) {
                    max_dep = dep[i];
                    max_ind = i;
                }
            }
            return max_ind;
        };
        make(0);
        int s = deepest();
        make(s);
        int t = deepest();
        return make_pair(t, s);
    }
    // Builds the binary-lifting table par_exp from par (requires make()).
    void init_lca() {
        int ct = 0;
        while (n >= (1 << ct)) ct++;
        ct++;
        par_exp = vector<vector<int>>(ct, vector<int>(n, -1));
        for (int i = 0; i < n; i++) par_exp[0][i] = par[i];
        for (int i = 0; i + 1 < ct; i++) {
            for (int j = 0; j < n; j++) {
                int halfway = par_exp[i][j];
                par_exp[i + 1][j] = (halfway == -1) ? -1 : par_exp[i][halfway];
            }
        }
        lca_ok = true;
    }
    // Lowest common ancestor of u and v under the current root.
    int lca(int u, int v) {
        if (!lca_ok) init_lca();
        int h = par_exp.size();
        if (dep[u] > dep[v]) swap(u, v);
        // Lift v to u's depth.
        for (int i = h - 1; i >= 0; i--) {
            if (((dep[v] - dep[u]) >> i) & 1) v = par_exp[i][v];
        }
        if (u == v) return u;
        // Lift both while their ancestors differ; they end just below the LCA.
        for (int i = h - 1; i >= 0; i--) {
            if (par_exp[i][u] != par_exp[i][v]) {
                u = par_exp[i][u];
                v = par_exp[i][v];
            }
        }
        return par_exp[0][u];
    }
    // Returns the vertex reached by following parent links k times from v, or
    // -1 when k is negative or exceeds v's depth.
    int up(int v, int k) {
        if (!lca_ok) init_lca();
        if (k < 0 || dep[v] < k) return -1;
        int h = par_exp.size();
        for (int i = h - 1; i >= 0; i--) {
            if ((k >> i) & 1) v = par_exp[i][v];
        }
        return v;
    }
    // Middle of the s-t path. One vertex when the distance is even; otherwise
    // the two adjacent middle vertices, the one nearer the deeper endpoint
    // first.
    vector<int> mid_point(int s, int t) {
        if (!lca_ok) init_lca();
        if (dep[s] < dep[t]) swap(s, t);  // s is the deeper endpoint
        int p = lca(s, t);
        int dis = abs(dep[s] - dep[p]) + abs(dep[t] - dep[p]);
        int ue = dis / 2;

        if (dis % 2 == 0) return {up(s, ue)};
        return {up(s, ue), up(s, ue + 1)};
    }
    // HLD pass 1: computes subtree sizes stsize (rooted at u, parent p), resets
    // par[u] = p, and reorders gr_hld[u] so the heaviest child comes first and
    // the parent is moved to the back (where the loop stops).
    void BuildStsize(int u, int p) {
        stsize[u] = 1, par[u] = p;
        for (int &v : gr_hld[u]) {
            if (v == p) {
                if (v == gr_hld[u].back())
                    break;
                else
                    swap(v, gr_hld[u].back());
            }
            BuildStsize(v, u);
            stsize[u] += stsize[v];
            if (stsize[v] > stsize[gr_hld[u][0]]) {
                swap(v, gr_hld[u][0]);
            }
        }
    }
    // HLD pass 2: assigns preorder indices in[u] (heavy child first), out[u]
    // (one past the last index in u's subtree), and pathtop (the top vertex of
    // u's heavy path).
    void BuildPath(int u, int p, int &tm) {
        in[u] = tm++;
        for (int v : gr_hld[u]) {
            if (v == p) continue;
            pathtop[v] = (v == gr_hld[u][0] ? pathtop[u] : v);
            BuildPath(v, u, tm);
        }
        out[u] = tm;
    }
    // Prepares the heavy-light decomposition for the current root (call after
    // make()).
    void init_hld() {
        gr_hld = gr;
        stsize = pathtop = in = out = vector<int>(n);
        if (!par_exp.size()) init_lca();
        int tm = 0;
        BuildStsize(root, -1);
        pathtop[root] = root;
        BuildPath(root, -1, tm);
    }
    // Half-open HLD index ranges [l, r) whose union covers the vertices of the
    // a-b path. With edgeQuery, the LCA's own index is dropped, so the ranges
    // cover the child endpoint of every path edge.
    vector<pair<int, int>> path_query(int a, int b, bool edgeQuery = false) {
        int pta = pathtop[a], ptb = pathtop[b];
        vector<pair<int, int>> ret;
        // Climb from whichever heavy path starts lower until both share one.
        while (pathtop[a] != pathtop[b]) {
            if (in[pta] > in[ptb]) {
                ret.push_back({in[pta], in[a] + 1});
                a = par[pta], pta = pathtop[a];
            } else {
                ret.push_back({in[ptb], in[b] + 1});
                b = par[ptb], ptb = pathtop[b];
            }
        }
        if (in[a] > in[b]) swap(a, b);
        ret.push_back({in[a], in[b] + 1});
        if (edgeQuery) {
            int c = lca(a, b);
            for (auto &p : ret) {
                if (p.first == in[c]) p.first++;
            }
        }
        return ret;
    }
    // Half-open HLD index range covering d's subtree (excluding d itself when
    // edgeQuery is set).
    vector<pair<int, int>> subtree_query(int d, bool edgeQuery = false) {
        return {{in[d] + edgeQuery, out[d]}};
    }
    // Index of vertex d in HLD order.
    int hld_ind(int d) { return in[d]; }
    // BFS template: visits vertices from root in BFS order and then in reverse
    // BFS order. `order` is local and discarded; fill in work at the marked
    // points when copying this template.
    void bfs(int root) {
        al = vector<int>(n);
        al[root] = 1;
        queue<int> q;
        stack<int> st;
        q.push(root);
        st.push(root);
        vector<int> order;
        while (!q.empty()) {
            int d = q.front();
            order.push_back(d);
            // d is the vertex currently being processed (BFS order).
            q.pop();
            for (auto &p : gr[d]) {
                if (al[p]) continue;
                al[p] = 1;
                q.push(p);
                st.push(p);
                // p is discovered here.
            }
        }
        while (!st.empty()) {
            int d = st.top();
            st.pop();
            order.push_back(d);
            // Vertices are visited here in reverse BFS order.
        }
    }
    // DFS template entry: sizes val, resets visited flags, and starts at root.
    void dfs(int root) {
        val.resize(n);
        al = vector<int>(n);
        al[root] = 1;
        dfs_(root);
    }
    vector<long long> val;
    long long nw = 0;
    // DFS template body. It only acts when edgeCost is set: val[d] receives the
    // current nw on entry. NOTE(review): nw is a shared member that is reset at
    // each vertex and accumulated across siblings, so val depends on the visit
    // order; confirm the intended semantics before reusing.
    void dfs_(int d) {
        if (edgeCost) {
            val[d] = nw;
            nw = 0;
            for (auto &[p, c] : gr_edge_cost[d]) {
                if (al[p]) continue;
                nw += c;
                al[p] = 1;
                dfs_(p);
            }
            return;
        }
    }
};
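/* Usage */
/*
   Tree<int> tr(n)     // declare a tree on vertices 0..n-1
   tr.connect(a,b)     // add the edge between vertices a and b
   tr.make(d)          // root the tree at d and build par, dep, chi (and subsize)
   tr.rad()            // return the two endpoints of a diameter (re-roots the tree)

   tr.lca(a,b)         // return the lowest common ancestor of a and b
   tr.up(v,k)          // return the vertex k steps above v (-1 if it does not exist)
   tr.mid_point(a, b)  // return the middle of the a-b path as vector<int>:
                       //   one vertex if the distance is even, two if it is odd

   tr.init_hld()       // prepare the heavy-light decomposition (call after make)
   tr.hld_ind(d)       // return d's index in HLD order
   tr.path_query(s,t)  // return half-open index ranges covering the s-t path
   tr.subtree_query(d) // return the half-open index range covering d's subtree
*/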

// Reads a tree with n vertices and q queries (S, T), all 1-indexed, and prints
// one answer per query. The answer is 0 when dist(S, T) is odd. Otherwise it is
// the number of vertices whose route to the S-T path ends at its midpoint M
// without passing through M's neighbours on that path. Presumably these are
// the vertices equidistant from S and T (problem No.2337).
// No per-query debug output: writing each midpoint to unbuffered stderr only
// cost time.
void sol() {
    int n, q;
    cin >> n >> q;
    Tree<int> tr(n);
    rep(i, 0, n - 1) {
        int a, b;
        cin >> a >> b;
        a--, b--;
        tr.connect(a, b);
    }
    tr.make(0);
    tr.init_lca();
    while (q--) {
        int qw, er;
        cin >> qw >> er;
        qw--, er--;
        // Degenerate query (S == T): under the equidistant reading every vertex
        // qualifies. Without this guard the code below would call up(v, -1) and
        // index subsize with the result.
        if (qw == er) {
            cout << n << endl;
            continue;
        }
        auto midv = tr.mid_point(qw, er);
        if (midv.size() > 1) {  // odd distance: no vertex sits exactly halfway
            cout << 0 << endl;
            continue;
        }
        int mid = midv[0];
        if (tr.dep[qw] < tr.dep[er]) swap(qw, er);  // make qw the deeper endpoint
        int ue = tr.dep[qw] - tr.dep[mid];           // steps from qw up to mid (>= 1)
        if (tr.dep[qw] == tr.dep[er]) {
            // Equal depths: mid is the LCA. Exclude mid's two child subtrees that
            // contain qw and er; every other vertex reaches the path at mid.
            int ans = n;
            ans -= tr.subsize[tr.up(qw, ue - 1)];
            ans -= tr.subsize[tr.up(er, ue - 1)];
            cout << ans << endl;
            continue;
        }
        // Unequal depths: mid = up(qw, dist / 2) is an ancestor of qw strictly
        // below the LCA, so er lies outside mid's subtree (it is reached through
        // mid's parent). Count mid's subtree minus the child subtree toward qw.
        int ans = tr.subsize[mid] - tr.subsize[tr.up(qw, ue - 1)];
        cout << ans << endl;
    }
}
0