結果

問題 No.3194 Do Optimize Your Solution
ユーザー 👑 tatyam
提出日時 2025-06-24 19:24:56
言語 C++17 (gcc 13.3.0 + boost 1.87.0)
結果
MLE  
実行時間 -
コード長 8,978 bytes
コンパイル時間 2,787 ms
コンパイル使用メモリ 233,128 KB
実行使用メモリ 821,828 KB
最終ジャッジ日時 2025-06-27 20:52:08
合計ジャッジ時間 8,345 ms
ジャッジサーバーID (参考情報): judge1 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン | 結果
sample | AC * 1, MLE * 1
other | MLE * 1, -- * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
using u32 = uint32_t;
using u64 = uint64_t;

// --- static_modint<m>: residues modulo a compile-time modulus m ---
//
// The stored value is always reduced into [0, m) and every operation is
// constexpr.  inv() uses Fermat's little theorem (x^(m-2)) when m is prime and
// the extended Euclidean algorithm otherwise.  The Euclid helper runs in 64-bit
// signed arithmetic so that moduli above INT32_MAX (allowed by the uint32_t
// template parameter) still get correct inverses.
template<uint32_t m> class static_modint {
    using mint = static_modint;
    uint32_t _v = 0;
    // true iff m is prime; defined (and evaluated at compile time) after the class.
    static const bool prime;

    // Extended Euclid.  For b >= 1 returns {g, x} with g = gcd(a, b),
    // a * x ≡ g (mod b) and 0 <= x < b / g.
    static constexpr pair<int64_t, int64_t> inv_gcd(int64_t a, int64_t b) {
        a %= b;
        if (a < 0) a += b;
        if (a == 0) return {b, 0};
        // Invariants: s ≡ m0 * a and t ≡ m1 * a (mod b).
        int64_t s = b, t = a, m0 = 0, m1 = 1;
        while (t) {
            const int64_t u = s / t;
            s -= t * u;
            m0 -= m1 * u;
            // Explicit swaps: std::swap is not constexpr before C++20.
            int64_t tmp = s;  s = t;  t = tmp;
            tmp = m0;  m0 = m1;  m1 = tmp;
        }
        if (m0 < 0) m0 += b / s;
        return {s, m0};
    }
public:
    // Wraps v without reducing it; the caller must pass v < m (not checked).
    static constexpr mint raw(uint32_t v) {
        mint a; a._v = v; return a;
    }
    constexpr static_modint() = default;
    // Reduces any integral value (negative ones included) into [0, m).
    template<class T>
    constexpr static_modint(T v) {
        static_assert(is_integral<T>::value, "T is not integral");
        if constexpr (is_signed<T>::value) {
            int64_t x = int64_t(v % int64_t(m));
            if (x < 0) x += m;
            _v = uint32_t(x);
        } else {
            _v = uint32_t(v % m);
        }
    }
    static constexpr uint32_t mod() { return m; }
    constexpr uint32_t val() const { return _v; }
    constexpr mint& operator++() { return *this += 1; }
    constexpr mint& operator--() { return *this -= 1; }
    constexpr mint operator++(int) { mint tmp = *this; ++*this; return tmp; }
    constexpr mint operator--(int) { mint tmp = *this; --*this; return tmp; }
    constexpr mint& operator+=(mint rhs) {
        // Subtract m first (relying on unsigned wrap-around) so the sum never
        // has to be represented above 2^32 - 1.
        if (_v >= m - rhs._v) _v -= m;
        _v += rhs._v;
        return *this;
    }
    constexpr mint& operator-=(mint rhs) {
        if (_v < rhs._v) _v += m;
        _v -= rhs._v;
        return *this;
    }
    constexpr mint& operator*=(mint rhs) { return *this = *this * rhs; }
    constexpr mint& operator/=(mint rhs) { return *this *= rhs.inv(); }
    constexpr mint operator+() const { return *this; }
    constexpr mint operator-() const { return mint{} - *this; }
    // Binary exponentiation; requires n >= 0.
    constexpr mint pow(long long n) const {
        assert(n >= 0);
        mint x = *this, r = 1;
        while (n > 0) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    // Multiplicative inverse; asserts that it exists.
    constexpr mint inv() const {
        if (prime) {
            assert(_v != 0);
            return pow(m - 2);
        } else {
            auto eg = inv_gcd(_v, m);
            assert(eg.first == 1);
            return eg.second;
        }
    }
    friend constexpr mint operator+(mint a, mint b) { return a += b; }
    friend constexpr mint operator-(mint a, mint b) { return a -= b; }
    friend constexpr mint operator*(mint a, mint b) { return uint64_t(a._v) * b._v; }
    friend constexpr mint operator/(mint a, mint b) { return a /= b; }
    friend constexpr bool operator==(mint a, mint b) { return a._v == b._v; }
    friend constexpr bool operator!=(mint a, mint b) { return a._v != b._v; }
};
// Compile-time primality of m: Miller-Rabin with witnesses 2, 7 and 61.
// This witness set is known to be exact for every n < 4,759,123,141, which
// covers every uint32_t modulus.
template<uint32_t m>
constexpr bool static_modint<m>::prime = [](){
    if (m < 2) return false;
    if (m == 2 || m == 7 || m == 61) return true;
    if (m % 2 == 0) return false;
    // m - 1 = d * 2^s with d odd.
    uint32_t d = m - 1;
    while ((d & 1) == 0) d >>= 1;
    for (uint32_t a : {2u, 7u, 61u}) {
        if (a % m == 0) continue;
        auto y = static_modint<m>(a).pow(d);
        uint32_t t = d;
        while (t != m - 1 && y != 1 && y != static_modint<m>(m - 1)) {
            y *= y;
            t <<= 1;
        }
        if (y != static_modint<m>(m - 1) && (t & 1) == 0) return false;
    }
    return true;
}();

using mint = static_modint<1000000007>;
istream& operator>>(istream& in, mint& x) { long long a; in >> a; x = a; return in; }
ostream& operator<<(ostream& out, mint x) { return out << x.val(); }
constexpr mint operator""_M(unsigned long long x) { return static_cast<mint>(x); }

void solve() {
    u64 N;
    cin >> N;
    vector<tuple<u32,u32,u32>> edges(N-1);
    vector<int> deg(N, 0);
    for (int i = 0; i < int(N)-1; i++) {
        u32 u, v, w = 1;
        cin >> u >> v;
        --u; --v;
        edges[i] = {u,v,w};
        deg[u]++; deg[v]++;
    }
    vector<vector<pair<u32,u32>>> A(N);
    for (int i = 0; i < N; i++) A[i].reserve(deg[i]);
    for (auto& e : edges) {
        auto [u,v,w] = e;
        A[u].emplace_back(v,w);
        A[v].emplace_back(u,w);
    }
    fill(deg.begin(), deg.end(), 0);
    for (int i = 0; i < int(N)-1; i++) {
        u32 u, v, w = 1;
        tie(u,v,w) = edges[i];
        cin >> u >> v;
        --u; --v;
        edges[i] = {u,v,w};
        deg[u]++; deg[v]++;
    }
    vector<vector<pair<u32,u32>>> B(N);
    for (int i = 0; i < N; i++) B[i].reserve(deg[i]);
    for (auto& e : edges) {
        auto [u,v,w] = e;
        B[u].emplace_back(v,w);
        B[v].emplace_back(u,w);
    }

    // Heavy-Light Decomposition on A
    vector<u32> siz(N,1), heavy_parent(N, u32(-1)), dist_parent(N,0);
    auto hld = [&](auto&& self, int v) -> void {
        if (A[v].empty()) return;
        for (auto [ch,w] : A[v]) {
            auto& vec = A[ch];
            for (auto it = vec.begin(); it != vec.end(); ++it) {
                if (it->first == v && it->second == w) {
                    vec.erase(it);
                    break;
                }
            }
            dist_parent[ch] = w;
            self(self, ch);
            siz[v] += siz[ch];
        }
        int best = 0;
        for (int i = 1; i < (int)A[v].size(); i++) {
            if (siz[A[v][i].first] > siz[A[v][best].first]) {
                best = i;
            }
        }
        u32 hv = A[v][best].first;
        if (v != 0) heavy_parent[hv] = v;
        swap(A[v][0], A[v][best]);
    };
    hld(hld, 0);

    // Centroid Decomposition on B
    struct T { mint cnt = 0, sum = 0; };
    vector<vector<T>> sum_dist(N);
    for (int i = 0; i < N; i++) sum_dist[i].resize(B[i].size()+1);
    int LOG = 0; while ((1u<<LOG) <= N) ++LOG;
    vector<vector<tuple<T*,T*,mint>>> cent(N);
    for (auto& v : cent) v.reserve(LOG+1);
    vector<bool> deleted(N,false);
    vector<pair<int,int>> cc;

    auto dfs_size = [&](auto&& self, int p, int v) -> int {
        cc.emplace_back(v,0);
        int s = 1;
        for (auto [nx,w] : B[v]) if (nx!=p && !deleted[nx]) {
            s += self(self, v, nx);
        }
        cc.back().second = s;
        return s;
    };

    auto build = [&](auto&& self, int entry) -> void {
        cc.clear();
        int total = dfs_size(dfs_size, -1, entry);
        int half = total/2;
        int best_i = 0;
        for (int i = 1; i < (int)cc.size(); i++) {
            if (cc[i].second >= half && cc[i].second < cc[best_i].second)
                best_i = i;
        }
        int v = cc[best_i].first;
        deleted[v] = true;

        T* P = &sum_dist[v].back();
        cent[v].emplace_back(P, nullptr, mint(0));
        for (int i = 0; i < (int)B[v].size(); i++) {
            auto [u,w] = B[v][i];
            if (deleted[u]) continue;
            T* Q = &sum_dist[v][i];
            auto dfs = [&](auto&& self2, int p2, int x, mint d) -> void {
                cent[x].emplace_back(P, Q, d);
                for (auto [y,ww] : B[x]) {
                    if (y!=p2 && !deleted[y]) self2(self2, x, y, d+ww);
                }
            };
            dfs(dfs, v, u, mint(w));
        }
        for (auto [u,w] : B[v]) if (!deleted[u]) {
            self(self, u);
        }
    };
    build(build, 0);

    // 初期化
    for (int v = 0; v < N; v++) {
        for (auto& t : cent[v]) {
            auto [P,Q,d] = t;
            P->cnt += 1;  P->sum += d;
            if (Q) { Q->cnt += 1;  Q->sum += d; }
        }
    }

    vector<bool> active(N,false);
    mint ans = 0, cur = 0;

    auto add = [&](int v, mint x) {
        active[v] = !active[v];
        for (auto& t : cent[v]) {
            auto [P,Q,d] = t;
            // P側
            P->cnt -= x;  P->sum -= x*d;
            cur += x*(P->sum + P->cnt*d);
            P->cnt -= x;  P->sum -= x*d;
            if (!Q) break;
            // Q側
            Q->cnt -= x;  Q->sum -= x*d;
            cur -= x*(Q->sum + Q->cnt*d);
            Q->cnt -= x;  Q->sum -= x*d;
        }
    };

    auto set_sub = [&](auto&& self, int v) -> void {
        add(v, 1);
        for (auto [ch,w] : A[v]) self(self, ch);
    };
    auto reset_sub = [&](auto&& self, int v) -> void {
        add(v, mint(-1));
        for (auto [ch,w] : A[v]) self(self, ch);
    };
    auto climb = [&](auto&& self, int v) -> void {
        add(v, 1);
        for (int i = 1; i < (int)A[v].size(); i++)
            set_sub(set_sub, A[v][i].first);
        ans += cur * dist_parent[v];
        int p = heavy_parent[v];
        if (p != int(u32(-1))) self(self, p);
        else reset_sub(reset_sub, v);
    };

    for (int i = 0; i < N; i++) {
        if (A[i].empty()) climb(climb, i);
    }

    cout << ans * 2 << "\n";
}

int main() {
    // Decouple iostreams from C stdio and untie cin from cout for faster bulk I/O.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // The input holds exactly one test case.
    constexpr int test_cases = 1;
    for (int t = 0; t < test_cases; ++t) solve();
    return 0;
}
0