結果

問題 No.922 東北きりきざむたん
ユーザー nebukuro09
提出日時 2019-11-08 23:01:18
言語 D
(dmd 2.109.1)
結果
AC  
実行時間 428 ms / 2,000 ms
コード長 4,375 bytes
コンパイル時間 896 ms
コンパイル使用メモリ 120,676 KB
実行使用メモリ 48,204 KB
最終ジャッジ日時 2024-06-22 03:01:36
合計ジャッジ時間 8,154 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 26
権限があれば一括ダウンロードができます

ソースコード

diff #

module simple;import std.stdio, std.array, std.string, std.conv, std.algorithm;
import std.typecons, std.range, std.random, std.math, std.container;
import std.numeric, std.bigint, core.bitop, std.bitmanip, core.stdc.string;

/// Reads a forest with N vertices and M edges, then Q vertex pairs (a, b), and
/// prints one total:
///  * a pair inside one component adds the tree distance between a and b;
///  * otherwise a and b each gain weight +1 (V), and afterwards every
///    component adds min over its vertices c of sum_j V[j] * dist(j, c).
/// Vertices are 1-indexed in the input and 0-indexed internally.
void main() {
    auto s = readln.split.map!(to!int);
    auto N = s[0];
    auto M = s[1];
    auto Q = s[2];

    // Vertex N is a virtual hub linked to every component root. It turns the
    // forest into one tree, so a single LowestCommonAncestor can answer every
    // in-component distance query. The per-component traversals skip it.
    auto G = new int[][](N+1);
    auto uf = new UnionFind(N);
    foreach (_; 0..M) {
        s = readln.split.map!(to!int);
        G[s[0]-1] ~= s[1]-1;
        G[s[1]-1] ~= s[0]-1;
        uf.unite(s[0]-1, s[1]-1);
    }
    foreach (i; 0..N) if (uf.find(i) == i) G[N] ~= i, G[i] ~= N;
    auto lca = new LowestCommonAncestor(G, N);

    long ans = 0;
    auto V = new long[](N);      // cross-component pair endpoints at each vertex
    auto subsum = new long[](N); // sum of V[j] * dist(j, v) over j in v's subtree
    auto subnum = new long[](N); // sum of V[j] over j in v's subtree
    auto dp = new long[](N);     // sum of V[j] * dist(j, v) over v's whole component
    auto parent = new int[](N);  // parent in the per-component traversal, -1 at the root

    while (Q--) {
        s = readln.split.map!(to!int);
        auto a = s[0] - 1;
        auto b = s[1] - 1;
        if (uf.find(a) == uf.find(b)) {
            ans += lca.dist(a, b);
        } else {
            V[a] += 1;
            V[b] += 1;
        }
    }

    // Rerooting DP per component. It walks an explicit BFS order instead of
    // recursing, so a long path-shaped component cannot overflow the call stack.
    foreach (root; 0..N) if (uf.find(root) == root) {
        // BFS order: every vertex appears after its parent.
        int[] order = [root];
        parent[root] = -1;
        for (size_t k = 0; k < order.length; k++) {
            immutable u = order[k];
            foreach (v; G[u]) if (v != parent[u] && v != N) {
                parent[v] = u;
                order ~= v;
            }
        }

        // Bottom-up: in reverse BFS order, all children of u are finished
        // before u itself, so subnum[u] already holds their contribution.
        foreach_reverse (u; order) {
            subnum[u] += V[u];
            if (parent[u] != -1) {
                subsum[parent[u]] += subsum[u] + subnum[u];
                subnum[parent[u]] += subnum[u];
            }
        }

        // Top-down reroot: moving the centre from parent[u] to u brings the
        // subnum[u] weight inside u's subtree one step closer and the rest of
        // the component's weight one step farther.
        immutable total = subnum[root];
        dp[root] = subsum[root];
        foreach (u; order[1..$]) dp[u] = dp[parent[u]] + total - 2 * subnum[u];

        long best = long.max;
        foreach (u; order) best = min(best, dp[u]);
        ans += best;
    }

    ans.writeln;
}

/// Disjoint-set union over vertices 0 .. n-1 using union by size and path
/// compression. Every root also keeps the list of its component's members.
class UnionFind {
    import std.algorithm : swap;

    int n;
    /// table[v] < 0: v is a root and -table[v] is its component size.
    /// table[v] >= 0: table[v] is v's parent.
    int[] table;
    /// group[r] lists all members of r's component while r is a root.
    /// It is emptied once r is absorbed into another root.
    int[][] group;

    this(int n) {
        this.n = n;
        table = new int[](n);
        group = new int[][](n);
        foreach (int v; 0 .. n) {
            table[v] = -1;
            group[v] = [v];
        }
    }

    /// Returns the root of x's component and points every vertex on the
    /// visited path directly at that root.
    int find(int x) {
        int root = x;
        while (table[root] >= 0) root = table[root];
        // Second pass: compress the path from x up to the root.
        while (x != root) {
            immutable next = table[x];
            table[x] = root;
            x = next;
        }
        return root;
    }

    /// Merges the components of x and y; a no-op if they already coincide.
    /// The root of the larger component (ties: x's root) survives.
    void unite(int x, int y) {
        int big = find(x);
        int small = find(y);
        if (big == small) return;
        if (-table[big] < -table[small]) swap(big, small);
        table[big] += table[small];
        table[small] = big;
        group[big] ~= group[small];
        group[small] = [];
    }

    /// True when x and y belong to the same component.
    bool same(int x, int y) {
        immutable rx = find(x);
        immutable ry = find(y);
        return rx == ry;
    }
}

// Lowest common ancestor on a rooted tree by binary lifting (doubling):
// O(n log n) preprocessing, O(log n) per lca/dist query.

class LowestCommonAncestor {
    import std.algorithm : swap;
    import std.conv : to;
    import std.typecons : Tuple, tuple;
    import core.bitop : bsr;

    int n, root, lgn;
    int[][] graph;  // private copy of the adjacency list
    int[] depth;    // depth[v]: number of edges between root and v
    int[][] dp;     // dp[v][k]: the 2^k-th ancestor of v, or -1 above the root

    /// Builds the lifting table for `graph` rooted at `root`.
    /// Assumes `graph` is a tree over vertices 0 .. graph.length-1: the
    /// explicit DFS stack is sized for each vertex being pushed once, and
    /// vertices unreachable from `root` are never visited.
    this(const int[][] graph, int root) {
        n = graph.length.to!int;
        this.root = root;
        this.graph = new int[][](n);
        foreach (i; 0..n) this.graph[i] = graph[i].dup;

        // 2^lgn > 4n exceeds every possible depth, so lifting never runs
        // out of levels.
        lgn = bsr(n) + 3;
        depth = new int[](n);
        dp = new int[][](n, lgn);

        construct;
    }

    /// Returns the lowest common ancestor of vertices a and b.
    int lca(int a, int b) {
        if (depth[a] < depth[b]) swap(a, b);

        // Lift the deeper vertex a up to b's depth.
        int diff = depth[a] - depth[b];
        foreach_reverse (i; 0..lgn) if (diff & (1<<i)) a = dp[a][i];

        if (a == b) return a;

        // Jump both vertices by decreasing powers of two while their
        // ancestors still differ. Both are at the same depth, so their
        // ancestors are either both valid or both -1. After a single
        // descending pass, a and b are distinct children of the LCA.
        foreach_reverse (k; 0..lgn) {
            if (dp[a][k] != dp[b][k]) {
                a = dp[a][k];
                b = dp[b][k];
            }
        }

        return dp[a][0];
    }

    /// Number of edges on the tree path between u and v.
    int dist(int u, int v) {
        return depth[u] + depth[v] - 2 * depth[lca(u, v)];
    }

    /// Fills depth[] and dp[][0] with an iterative DFS from root. It then
    /// derives dp[][k+1] as the 2^k-th ancestor of the 2^k-th ancestor.
    private void construct() {
        auto stack = new Tuple!(int, int, int)[](n+10);  // (vertex, parent, depth)
        int sp = 0;
        stack[0] = tuple(root, -1, 0);
        while (sp >= 0) {
            auto u = stack[sp][0];
            auto p = stack[sp][1];
            auto d = stack[sp][2];
            sp -= 1;
            dp[u][0] = p;
            depth[u] = d;
            foreach (v; graph[u]) if (v != p) stack[++sp] = tuple(v, u, d+1);
        }

        foreach (k; 0..lgn-1)
            foreach (i; 0..n)
                dp[i][k+1] = (dp[i][k] == -1) ? -1 : dp[dp[i][k]][k];
    }
}
0