結果

問題 No.901 K-ary εxtrεεmε
ユーザー gyouzasushigyouzasushi
提出日時 2022-02-11 10:49:54
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
TLE  
実行時間 -
コード長 8,248 bytes
コンパイル時間 3,473 ms
コンパイル使用メモリ 238,056 KB
実行使用メモリ 49,008 KB
最終ジャッジ日時 2024-06-27 04:44:15
合計ジャッジ時間 21,712 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 214 ms
49,008 KB
testcase_01 AC 2 ms
5,376 KB
testcase_02 AC 4 ms
5,376 KB
testcase_03 AC 3 ms
5,376 KB
testcase_04 AC 4 ms
5,376 KB
testcase_05 AC 3 ms
5,376 KB
testcase_06 AC 4 ms
5,376 KB
testcase_07 AC 1,855 ms
38,656 KB
testcase_08 AC 1,856 ms
38,460 KB
testcase_09 AC 1,851 ms
38,400 KB
testcase_10 AC 1,845 ms
38,404 KB
testcase_11 AC 1,838 ms
38,528 KB
testcase_12 AC 656 ms
38,656 KB
testcase_13 AC 643 ms
38,784 KB
testcase_14 AC 645 ms
38,656 KB
testcase_15 AC 662 ms
38,528 KB
testcase_16 AC 644 ms
38,656 KB
testcase_17 TLE -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
testcase_29 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
#define rep(i, n) for (int i = 0; i < (int)(n); i++)
#define rrep(i, n) for (int i = (int)(n - 1); i >= 0; i--)
#define all(x) (x).begin(), (x).end()
#define sz(x) int(x.size())
using namespace std;
using ll = long long;
const int INF = 1e9;
const ll LINF = 1e18;
/// Raises `a` to `b` when `b` is strictly greater.
/// @return true if `a` was overwritten, false if it was left untouched.
template <class T>
bool chmax(T& a, const T& b) {
    const bool improved = a < b;
    if (improved) a = b;
    return improved;
}
/// Lowers `a` to `b` when `b` is strictly smaller.
/// @return true if `a` was overwritten, false if it was left untouched.
template <class T>
bool chmin(T& a, const T& b) {
    const bool improved = b < a;
    if (improved) a = b;
    return improved;
}
/// Base case: a value-initialised one-dimensional vector of length `a`.
template <class T>
vector<T> make_vec(size_t a) {
    vector<T> row(a);
    return row;
}
/// Recursive case: `a` copies of the vector produced by the remaining
/// dimensions `ts...`, giving a nested vector with one level per argument.
template <class T, class... Ts>
auto make_vec(size_t a, Ts... ts) {
    auto inner = make_vec<T>(ts...);
    return vector<decltype(inner)>(a, inner);
}
/// Fills every existing slot of `v`, front to back, from `is`.
/// The vector is not resized; the caller decides how many values are read.
template <typename T>
istream& operator>>(istream& is, vector<T>& v) {
    for (T& slot : v) {
        is >> slot;
    }
    return is;
}
/// Writes the elements of `v` separated by single spaces, with no leading or
/// trailing separator and no newline.
template <typename T>
ostream& operator<<(ostream& os, const vector<T>& v) {
    bool first = true;
    for (const T& item : v) {
        if (!first) os << ' ';
        os << item;
        first = false;
    }
    return os;
}
#pragma region LowestCommonAncestor
/// Sparse table answering range-minimum queries over a fixed array of
/// (value, index) pairs. Pairs are compared lexicographically, so among equal
/// values the smallest index wins.
struct StaticRMQ {
public:
    /// Builds the table for `_v`, discarding any previous contents.
    /// An empty `_v` is accepted; every query on it then returns the sentinel.
    void init(const std::vector<std::pair<int, int>>& _v) {
        _n = int(_v.size());

        // floor_log2[i] = floor(log2(i)) for i >= 1. The table always has at
        // least two entries so that indices 0 and 1 exist even when _n == 0.
        floor_log2.assign(std::max(_n, 1) + 1, 0);
        for (int i = 2; i <= _n; i++) floor_log2[i] = floor_log2[i >> 1] + 1;

        d.clear();
        if (_n == 0) return;

        // d[b][i] = min of the window [i, i + 2^b). Levels are stored
        // contiguously, so level b only needs _n - 2^b + 1 entries.
        d.push_back(_v);
        for (int b = 0; b < floor_log2[_n]; b++) {
            const int half = 1 << b;
            const int count = _n - 2 * half + 1;
            std::vector<std::pair<int, int>> level(count);
            for (int i = 0; i < count; i++) {
                level[i] = std::min(d[b][i], d[b][i + half]);
            }
            d.push_back(std::move(level));
        }
    }

    /// Minimum over the half-open range [l, r).
    /// Requires 0 <= l and r <= size when l < r; returns {1 << 30, 1 << 30}
    /// when the range is empty (l >= r).
    std::pair<int, int> prod(int l, int r) const {
        if (!(l < r)) return PINF;
        const int b = floor_log2[r - l];
        // Two (possibly overlapping) windows of width 2^b cover [l, r).
        return std::min(d[b][l], d[b][r - (1 << b)]);
    }

private:
    int _n = 0;
    std::vector<std::vector<std::pair<int, int>>> d;
    std::vector<int> floor_log2;
    const std::pair<int, int> PINF = {1 << 30, 1 << 30};
};
/// Range-argmin structure built from three pieces: the input is cut into
/// blocks of length `s`, block minima go into a sparse table, and in-block
/// queries are answered from a table indexed by the block's up/down pattern.
/// The in-block table models every step between neighbours as +1 or -1.
struct PlusMinusOneRMQ {
public:
    /// Preprocesses `_v`. The block length is max(1, floor(log2(n) / 2)).
    void init(const std::vector<int>& _v) {
        v = _v;
        _n = int(v.size());
        s = std::max(1, int(std::log2(_n) / 2));
        B = (_n + s - 1) / s;

        // Per block: its (min value, position) and a bitmask whose bit t is
        // set when the step from offset t to offset t + 1 goes up.
        std::vector<std::pair<int, int>> block_min(B);
        pattern.resize(B);
        for (int blk = 0, begin = 0; begin < _n; ++blk, begin += s) {
            const int end = std::min(_n, begin + s);
            int best = begin;
            int mask = 0;
            for (int j = begin; j < end; ++j) {
                if (v[j] < v[best]) best = j;
                if (j + 1 < end && v[j] < v[j + 1]) mask |= 1 << (j - begin);
            }
            block_min[blk] = {v[best], best};
            pattern[blk] = mask;
        }
        sparse_table.init(block_min);

        // lookup_table[mask][l][r] = offset of the leftmost minimum inside
        // [l, r) for a block with step pattern `mask`; empty ranges keep INF.
        const int pattern_count = 1 << (s - 1);
        lookup_table.resize(pattern_count);
        for (int mask = 0; mask < pattern_count; ++mask) {
            auto& table = lookup_table[mask];
            table.resize(s + 1);
            for (int l = 0; l <= s; ++l) {
                auto& row = table[l];
                row.resize(s + 1, INF);
                int lowest = 0;     // smallest relative height seen so far
                int lowest_at = l;  // offset where it was first reached
                int height = 0;     // height of offset r relative to offset l
                for (int r = l + 1; r <= s; ++r) {
                    row[r] = lowest_at;
                    if ((mask >> (r - 1)) & 1) {
                        ++height;
                    } else {
                        --height;
                    }
                    if (height < lowest) {
                        lowest = height;
                        lowest_at = r;
                    }
                }
            }
        }
    }

    /// Position of a minimum of v over [l, r); ties resolve to the smaller
    /// position.
    int prod(int l, int r) {
        const int first_block = (l + s - 1) / s;  // first block starting at or after l
        const int last_block = r / s;             // block that contains position r
        const int left_edge = s * first_block;
        const int right_edge = s * last_block;

        // Both ends fall strictly inside the same block.
        if (last_block < first_block) {
            return lookup_table[pattern[last_block]][l - right_edge]
                               [r - right_edge] +
                   right_edge;
        }

        int best = INF;
        if (first_block > 0) {
            // Tail of the block just before the first whole block.
            const int head =
                lookup_table[pattern[first_block - 1]][s - (left_edge - l)][s] +
                left_edge - s;
            best = argmin(best, head);
        }
        // Whole blocks in between.
        best = argmin(best, sparse_table.prod(first_block, last_block).second);
        if (last_block < B) {
            // Prefix of the block that contains r.
            const int tail =
                lookup_table[pattern[last_block]][0][r - right_edge] +
                right_edge;
            best = argmin(best, tail);
        }
        return best;
    }

private:
    int _n;
    int s, B;
    StaticRMQ sparse_table;
    std::vector<std::vector<std::vector<int>>> lookup_table;
    std::vector<int> pattern, v;
    const int INF = 1 << 30;
    /// Picks whichever of positions i, j holds the smaller value. Anything
    /// >= INF means "no candidate"; ties and missing candidates fall back to
    /// the smaller number.
    int argmin(int i, int j) {
        const bool both_valid = i < INF && j < INF;
        if (both_valid && v[i] != v[j]) {
            return v[i] < v[j] ? i : j;
        }
        return std::min(i, j);
    }
};
/// Lowest-common-ancestor queries on a rooted tree via an Euler tour and a
/// range-argmin over the tour's depth sequence.
/// Usage: construct with the vertex count, add_edge() each edge, build(), then
/// query.
struct LowestCommonAncestor {
public:
    LowestCommonAncestor() {
    }
    /// @param n    number of vertices (labelled 0 .. n-1)
    /// @param root vertex the tour starts from
    LowestCommonAncestor(int n, int root = 0)
        : _n(n), _root(root), g(n), id(n), vs(2 * n - 1), dep(2 * n - 1) {
    }

    /// Registers an undirected edge between `from` and `to`.
    void add_edge(int from, int to) {
        assert(0 <= from && from < _n);
        assert(0 <= to && to < _n);
        g[to].push_back(from);
        g[from].push_back(to);
    }

    /// Runs the Euler tour from the root and preprocesses the depth sequence.
    void build() {
        int cursor = 0;
        // Appends one tour entry.
        auto record = [&](int vertex, int level) {
            vs[cursor] = vertex;
            dep[cursor] = level;
            ++cursor;
        };
        auto walk = [&](auto&& self, int cur, int parent, int level) -> void {
            id[cur] = cursor;  // first appearance of `cur` in the tour
            record(cur, level);
            for (int child : g[cur]) {
                if (child == parent) continue;
                self(self, child, cur, level + 1);
                record(cur, level);  // re-enter `cur` after each subtree
            }
        };
        walk(walk, _root, -1, 0);
        rmq.init(dep);
    }

    /// LCA of u and v with respect to the root given at construction.
    int get(int u, int v) {
        int first = id[u];
        int last = id[v];
        if (first > last) std::swap(first, last);
        return vs[rmq.prod(first, last + 1)];
    }
    /// LCA of u and v when the tree is re-rooted at r, obtained by XOR-ing the
    /// three pairwise LCAs.
    int get(int u, int v, int r) {
        const int ru = get(r, u);
        const int uv = get(u, v);
        const int vr = get(v, r);
        return ru ^ uv ^ vr;
    }
    /// Number of edges between the root and u.
    int depth(int u) {
        return dep[id[u]];
    }
    /// Number of edges on the path between u and v.
    int dist(int u, int v) {
        const int top = get(u, v);
        return depth(u) + depth(v) - 2 * depth(top);
    }

    // private:
    int _n, _root;
    std::vector<std::vector<int>> g;  // adjacency lists
    // id[x]: first tour index of x; vs[k]: vertex at tour index k;
    // dep[k]: depth of vs[k].
    std::vector<int> id, vs, dep;
    PlusMinusOneRMQ rmq;
};
#pragma endregion
int main() {
    int n;
    cin >> n;
    LowestCommonAncestor g(n);
    vector<vector<int>> g_(n);
    using P = pair<int, int>;
    map<P, ll> w;
    rep(_, n - 1) {
        int u, v;
        ll _w;
        scanf("%d %d %lld", &u, &v, &_w);
        g.add_edge(u, v);
        g_[u].push_back(v);
        g_[v].push_back(u);
        w[P(u, v)] = _w;
        w[P(v, u)] = _w;
    }
    g.build();
    vector<vector<int>> gg(n);
    auto compress = [&](vector<int> x) -> void {
        rep(i, n) gg[i].clear();
        sort(all(x), [&](int u, int v) { return g.id[u] < g.id[v]; });
        vector<int> st = {x[0]};
        int k = sz(x);
        rep(i, k - 1) {
            int w = g.get(x[i], x[i + 1]);
            if (w != x[i]) {
                int pre = st.back();
                st.pop_back();
                while (!st.empty() && g.depth(st.back()) > g.depth(w)) {
                    gg[st.back()].push_back(pre);
                    gg[pre].push_back(st.back());
                    pre = st.back();
                    st.pop_back();
                }
                if (st.empty() || st.back() != w) {
                    st.push_back(w);
                    x.push_back(w);
                }
                gg[w].push_back(pre);
                gg[pre].push_back(w);
            }
            st.push_back(x[i + 1]);
        }
        while (sz(st) > 1) {
            int pre = st.back();
            st.pop_back();
            gg[st.back()].push_back(pre);
            gg[pre].push_back(st.back());
        }
    };
    vector<int> d(n);
    auto dfs0 = [&](auto dfs0, int pos, int pre) -> void {
        for (int nxt : g_[pos]) {
            if (nxt == pre) continue;
            d[nxt] = d[pos] + w[P(pos, nxt)];
            dfs0(dfs0, nxt, pos);
        }
        return;
    };
    dfs0(dfs0, 0, -1);
    auto dist = [&](int u, int v) -> ll {
        return d[u] + d[v] - d[g.get(u, v)] * 2;
    };
    auto dfs = [&](auto dfs, int pos, int pre) -> ll {
        ll ret = 0;
        for (int nxt : gg[pos]) {
            if (nxt == pre) continue;
            ret += dfs(dfs, nxt, pos) + dist(pos, nxt);
        }
        return ret;
    };
    int q;
    cin >> q;
    vector<int> x;
    while (q--) {
        int k;
        scanf("%d", &k);
        x.resize(k);
        rep(i, k) scanf("%d", &x[i]);
        compress(x);
        cout << dfs(dfs, x[0], -1) << '\n';
    }
}
0