結果

問題 No.3272 Separate Contractions
ユーザー apricity
提出日時 2025-07-26 23:47:33
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
CE  
(最新)
AC  
(最初)
実行時間 -
コード長 3,496 bytes
コンパイル時間 1,988 ms
コンパイル使用メモリ 91,784 KB
最終ジャッジ日時 2025-09-12 07:44:29
合計ジャッジ時間 7,327 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
コンパイルエラー時のメッセージ・ソースコードは、提出者また管理者しか表示できないようにしております。(リジャッジ後のコンパイルエラーは公開されます)
ただし、clay言語の場合は開発者のデバッグのため、公開されます。

コンパイルメッセージ
main.cpp:19:61: error: return type 'class std::tuple<int, int, int>' is incomplete
   19 | tuple<int, int, int> get_center(const vector<vector<int>> &g) {
      |                                                             ^
main.cpp: In function 'int main()':
main.cpp:55:10: error: 'void <structured bindings>' has incomplete type
   55 |     auto [diameter, center1, center2] = get_center(g);
      |          ^~~~~~~~~~~~~~~~~~~~~~~~~~~~

ソースコード

diff #

#include <iostream>
#include <vector>
using namespace std;

vector<int> get_distance(const vector<vector<int>> &g, int start) {
    int n = g.size();
    vector<int> dist(n);
    auto dfs = [&] (auto rec, int x, int p) -> void {
        for (int y : g[x]) if (y != p) {
            dist[y] = dist[x] + 1;
            rec(rec, y, x);
        }
    };
    dfs(dfs, start, -1);
    return dist;
}

// (diameter, center1, [center2])
tuple<int, int, int> get_center(const vector<vector<int>> &g) {
    int n = g.size();
    vector<int> dist = get_distance(g, 0);
    int max_dist = -1, argmax = -1;
    for (int i = 0; i < n; i++) if (max_dist < dist[i]) max_dist = dist[i], argmax = i;

    dist = get_distance(g, argmax);
    max_dist = argmax = -1;
    for (int i = 0; i < n; i++) if (max_dist < dist[i]) max_dist = dist[i], argmax = i;

    vector<int> pre(n, -1);
    for (int i = 0; i < n; i++) {
        for (int j : g[i]) if (dist[j] == dist[i] - 1) pre[i] = j;
    }
    vector<int> path;
    int cur = argmax;
    while (cur != -1) {
        path.emplace_back(cur);
        cur = pre[cur];
    }
    if (max_dist % 2) return {max_dist, path[max_dist/2], path[max_dist/2+1]};
    else return {max_dist, path[max_dist/2], -1};
}

int main() {
    int n; cin >> n;
    vector<vector<int>> g(n);
    vector<pair<int, int>> edge(n-1);
    for (int i = 0; i < n-1; i++) {
        int u, v; cin >> u >> v;
        u--; v--;
        g[u].emplace_back(v);
        g[v].emplace_back(u);
        edge[i] = {u, v};
    }

    auto [diameter, center1, center2] = get_center(g);
    int half = diameter / 2;

    vector<int> dep(n), sz(n), cnt(n), top(n);
    auto dfs = [&] (auto rec, int u, int p, int dep_cur, int top_cur) -> void {
        dep[u] = dep_cur;
        top[u] = top_cur;
        sz[u] = 1;
        if (dep[u] == half) cnt[u] = 1;
        for (int v : g[u]) if (v != p) {
            rec(rec, v, u, dep_cur + 1, top_cur);
            cnt[u] += cnt[v];
            sz[u] += sz[v];
        }
    };

    long long ans_before = 0;
    vector<long long> ans(n-1);

    if (diameter == half * 2) {
        int occur_half = 0, sum_sz = 0;
        for (int r : g[center1]) {
            dfs(dfs, r, center1, 1, r);
            if (cnt[r] > 0) occur_half++, sum_sz += sz[r];
        }

        for (int i = 0; i < n; i++) ans_before += dep[i] + half;

        for (int i = 0; i < n-1; i++) {
            auto [u, v] = edge[i];
            if (dep[u] > dep[v]) swap(u, v);

            ans[i] = ans_before;
            ans[i] -= dep[u] + half;
            ans[i] -= sz[v];
            if (occur_half == 2 and cnt[top[v]] > 0 and cnt[top[v]] == cnt[v]) {
                ans[i] -= sum_sz - sz[top[v]];
            }
        }
    }
    else {
        dfs(dfs, center1, center2, 0, center1);
        dfs(dfs, center2, center1, 0, center2);
        for (int i = 0; i < n; i++) ans_before += dep[i] + half + 1;

        for (int i = 0; i < n-1; i++) {
            auto [u, v] = edge[i];
            ans[i] = ans_before;

            if ((u == center1 and v == center2) or (u == center2 and v == center1)) {
                ans[i] -= half + n;
            }
            else{
                if (dep[u] > dep[v]) swap(u, v);
                ans[i] -= dep[u] + half + 1;
                ans[i] -= sz[v];
                if (cnt[top[v]] == cnt[v]) {
                    ans[i] -= n - sz[top[v]];
                }
            }
        }
    }

    for (int i = 0; i < n-1; i++) cout << ans[i] << "\n";
}
0