結果
問題 | No.3194 Do Optimize Your Solution |
ユーザー |
👑 ![]() |
提出日時 | 2025-06-24 19:24:56 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 | MLE |
実行時間 | - |
コード長 | 8,978 bytes |
コンパイル時間 | 2,787 ms |
コンパイル使用メモリ | 233,128 KB |
実行使用メモリ | 821,828 KB |
最終ジャッジ日時 | 2025-06-27 20:52:08 |
合計ジャッジ時間 | 8,345 ms |
ジャッジサーバーID (参考情報) | judge1 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 MLE * 1 |
other | MLE * 1 -- * 16 |
ソースコード
#include <bits/stdc++.h>
using namespace std;

using u32 = uint32_t;
using u64 = uint64_t;

// --- Definition of static_modint<MOD> ---
// Integer arithmetic modulo the compile-time constant `m`.
// The stored representative `_v` is kept in [0, m) by every operation below.
template<uint32_t m>
class static_modint {
    using mint = static_modint;
    uint32_t _v = 0;
    // True when `m` passes the primality test defined after the class;
    // selects which inversion strategy inv() uses.
    static const bool prime;

    // Extended Euclid on (a, b). inv() uses the result as {gcd, inverse}:
    // it asserts that `first` is 1 and returns `second` as the inverse of a.
    static constexpr pair<int32_t, int32_t> inv_gcd(int32_t a, int32_t b) {
        if (a == 0) return {b, 0};
        int32_t s = b, t = a, m0 = 0, m1 = 1;
        while (t) {
            int32_t u = s / t;
            s -= t * u;
            m0 -= m1 * u;
            swap(s, t);
            swap(m0, m1);
        }
        if (m0 < 0) m0 += b / s;
        return {s, m0};
    }

public:
    // Builds a value from `v` without reducing it; the caller must pass v < m.
    static constexpr mint raw(uint32_t v) {
        mint a;
        a._v = v;
        return a;
    }

    constexpr static_modint() = default;

    // Converting constructor from any integral type; negative inputs are
    // mapped to their non-negative residue.
    template<class T>
    constexpr static_modint(T v) {
        static_assert(is_integral<T>::value, "T is not integral");
        if constexpr (is_signed<T>::value) {
            int64_t x = int64_t(v % int64_t(m));
            if (x < 0) x += m;
            _v = uint32_t(x);
        } else {
            _v = uint32_t(v % m);
        }
    }

    static constexpr uint32_t mod() { return m; }
    constexpr uint32_t val() const { return _v; }

    constexpr mint& operator++() { return *this += 1; }
    constexpr mint& operator--() { return *this -= 1; }
    constexpr mint operator++(int) {
        mint tmp = *this;
        ++*this;
        return tmp;
    }
    constexpr mint operator--(int) {
        mint tmp = *this;
        --*this;
        return tmp;
    }

    // Both += and -= pre-adjust by m and rely on unsigned wrap-around so the
    // final value lands in [0, m) without a 64-bit intermediate.
    constexpr mint& operator+=(mint rhs) {
        if (_v >= m - rhs._v) _v -= m;
        _v += rhs._v;
        return *this;
    }
    constexpr mint& operator-=(mint rhs) {
        if (_v < rhs._v) _v += m;
        _v -= rhs._v;
        return *this;
    }
    constexpr mint& operator*=(mint rhs) { return *this = *this * rhs; }
    constexpr mint& operator/=(mint rhs) { return *this *= rhs.inv(); }

    constexpr mint operator+() const { return *this; }
    constexpr mint operator-() const { return mint{} - *this; }

    // Binary exponentiation; `n` must be non-negative.
    constexpr mint pow(long long n) const {
        assert(n >= 0);
        mint x = *this, r = 1;
        while (n > 0) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }

    // Multiplicative inverse: x^(m-2) when m is prime, otherwise via
    // inv_gcd (asserting that the value is coprime to m).
    constexpr mint inv() const {
        if (prime) {
            assert(_v != 0);
            return pow(m - 2);
        } else {
            auto eg = inv_gcd(_v, m);
            assert(eg.first == 1);
            return eg.second;
        }
    }

    friend constexpr mint operator+(mint a, mint b) { return a += b; }
    friend constexpr mint operator-(mint a, mint b) { return a -= b; }
    // The 64-bit product is reduced by the converting constructor.
    friend constexpr mint operator*(mint a, mint b) { return uint64_t(a._v) * b._v; }
    friend constexpr mint operator/(mint a, mint b) { return a /= b; }
    friend constexpr bool operator==(mint a, mint b) { return a._v == b._v; }
    friend constexpr bool operator!=(mint a, mint b) { return a._v != b._v; }
};

// Miller-Rabin primality test of `m` with the witnesses 2, 7 and 61.
template<uint32_t m>
constexpr bool static_modint<m>::prime = []() {
    if (m < 2) return false;
    if (m == 2 || m == 7 || m == 61) return true;
    if (m % 2 == 0) return false;
    uint32_t d = m - 1;
    while ((d & 1) == 0) d >>= 1;
    for (uint32_t a : {2u, 7u, 61u}) {
        if (a % m == 0) continue;
        auto y = static_modint<m>(a).pow(d);
        uint32_t t = d;
        while (t != m - 1 && y != 1 && y != static_modint<m>(m - 1)) {
            y *= y;
            t <<= 1;
        }
        if (y != static_modint<m>(m - 1) && (t & 1) == 0) return false;
    }
    return true;
}();

using mint = static_modint<1000000007>;

// Reads a signed integer and stores its residue.
istream& operator>>(istream& in, mint& x) {
    long long a;
    in >> a;
    x = a;
    return in;
}

// Prints the canonical representative in [0, mod).
ostream& operator<<(ostream& out, mint x) { return out << x.val(); }

constexpr mint operator""_M(unsigned long long x) { return static_cast<mint>(x); }

// Reads N and two trees on vertices 1..N (N-1 "u v" lines each, every edge
// given weight 1) from stdin and prints one value modulo 1e9+7.
//
// What is computed: for every non-root vertex v of the first tree A (rooted
// at vertex 0), let S be v's subtree in A; the code accumulates
//   dist_parent[v] * sum_{x in S, y not in S} dist_B(x, y)
// and prints twice the total.
// NOTE(review): this presumably equals the sum over ordered pairs of
// dist_A * dist_B required by the task -- confirm against the statement.
//
// Assumes N >= 1 and well-formed input.
void solve() {
    int N;
    cin >> N;

    using Adj = vector<vector<pair<u32, u32>>>;  // per vertex: (neighbour, weight)

    // Reads N-1 edges and returns the adjacency lists, each reserved to its
    // exact degree so that no list reallocates while being filled.
    auto read_tree = [&]() -> Adj {
        vector<tuple<u32, u32, u32>> edges(N - 1);
        vector<int> deg(N, 0);
        for (auto& e : edges) {
            u32 u, v, w = 1;
            cin >> u >> v;
            --u;
            --v;
            e = {u, v, w};
            deg[u]++;
            deg[v]++;
        }
        Adj g(N);
        for (int i = 0; i < N; i++) g[i].reserve(deg[i]);
        for (const auto& [u, v, w] : edges) {
            g[u].emplace_back(v, w);
            g[v].emplace_back(u, w);
        }
        return g;
    };

    Adj A = read_tree();
    Adj B = read_tree();

    // ---- Heavy-child preprocessing on A (rooted at 0) ----
    // After hld(): A[v] holds only v's children with the largest-subtree
    // child first; dist_parent[v] is the weight of the edge to v's parent;
    // heavy_parent[c] == v iff c is the heavy child of a non-root v
    // (NONE for every other vertex, so a climb stops below the root).
    const u32 NONE = u32(-1);
    vector<u32> siz(N, 1), heavy_parent(N, NONE), dist_parent(N, 0);
    auto hld = [&](auto&& self, int v) -> void {
        if (A[v].empty()) return;
        for (auto [ch, w] : A[v]) {
            // Drop the back-edge to the parent from the child's list.
            auto& vec = A[ch];
            for (auto it = vec.begin(); it != vec.end(); ++it) {
                if (it->first == u32(v) && it->second == w) {
                    vec.erase(it);
                    break;
                }
            }
            dist_parent[ch] = w;
            self(self, ch);
            siz[v] += siz[ch];
        }
        int best = 0;
        for (int i = 1; i < (int)A[v].size(); i++) {
            if (siz[A[v][i].first] > siz[A[v][best].first]) best = i;
        }
        u32 hv = A[v][best].first;
        if (v != 0) heavy_parent[hv] = v;
        swap(A[v][0], A[v][best]);
    };
    hld(hld, 0);

    // ---- Centroid decomposition on B ----
    // For a centroid c, sum_dist[c].back() aggregates the whole component of
    // c and sum_dist[c][i] aggregates the part reached through B[c][i].
    struct T {
        mint cnt = 0, sum = 0;
    };
    vector<vector<T>> sum_dist(N);
    for (int i = 0; i < N; i++) sum_dist[i].resize(B[i].size() + 1);

    int LOG = 0;
    while ((u64(1) << LOG) <= u64(N)) ++LOG;

    // cent[x]: one entry (P, Q, d) per centroid c whose component contains x,
    // outermost first. P is c's whole-component aggregate, Q the aggregate of
    // the branch of c containing x (nullptr when x == c), d = dist_B(c, x).
    vector<vector<tuple<T*, T*, mint>>> cent(N);
    for (auto& v : cent) v.reserve(LOG + 1);

    vector<bool> deleted(N, false);
    vector<pair<int, int>> cc;  // (vertex, subtree size) of the current component

    auto dfs_size = [&](auto&& self, int p, int v) -> int {
        // Remember v's own slot: once children have been visited, cc.back()
        // is the last descendant, not v. Writing the size through cc.back()
        // left every internal vertex with size 0, which made build() always
        // pick `entry` as the separator and let the depth of the
        // decomposition (and the size of `cent`) grow with the component.
        const size_t idx = cc.size();
        cc.emplace_back(v, 0);
        int s = 1;
        for (auto [nx, w] : B[v]) {
            if (nx != u32(p) && !deleted[nx]) s += self(self, v, nx);
        }
        cc[idx].second = s;
        return s;
    };

    auto build = [&](auto&& self, int entry) -> void {
        cc.clear();
        int total = dfs_size(dfs_size, -1, entry);
        int half = total / 2;
        // Separator: the vertex with the smallest subtree size >= total/2.
        int best_i = 0;
        for (int i = 1; i < (int)cc.size(); i++) {
            if (cc[i].second >= half && cc[i].second < cc[best_i].second) best_i = i;
        }
        int v = cc[best_i].first;
        deleted[v] = true;

        T* P = &sum_dist[v].back();
        cent[v].emplace_back(P, nullptr, mint(0));
        for (int i = 0; i < (int)B[v].size(); i++) {
            auto [u, w] = B[v][i];
            if (deleted[u]) continue;
            T* Q = &sum_dist[v][i];
            // Register every vertex of this branch with its distance to v.
            auto dfs = [&](auto&& self2, int p2, int x, mint d) -> void {
                cent[x].emplace_back(P, Q, d);
                for (auto [y, ww] : B[x]) {
                    if (y != u32(p2) && !deleted[y]) self2(self2, x, y, d + ww);
                }
            };
            dfs(dfs, v, u, mint(w));
        }
        for (auto [u, w] : B[v]) {
            if (!deleted[u]) self(self, u);
        }
    };
    build(build, 0);

    // Initialization: every vertex starts outside the current set and is
    // counted once in each aggregate it belongs to.
    for (int v = 0; v < N; v++) {
        for (auto& t : cent[v]) {
            auto [P, Q, d] = t;
            P->cnt += 1;
            P->sum += d;
            if (Q) {
                Q->cnt += 1;
                Q->sum += d;
            }
        }
    }

    mint ans = 0, cur = 0;

    // Moves v into the current set (x = 1) or back out of it (x = -1) and
    // updates `cur`. Each aggregate is adjusted by x both before and after it
    // is read: first v is removed from the side it was on, then it is counted
    // with the opposite sign, so the aggregates hold "outside minus inside".
    auto add = [&](int v, mint x) {
        for (auto& t : cent[v]) {
            auto [P, Q, d] = t;
            // P side
            P->cnt -= x;
            P->sum -= x * d;
            cur += x * (P->sum + P->cnt * d);
            P->cnt -= x;
            P->sum -= x * d;
            if (!Q) break;  // v is this centroid itself: always the last entry
            // Q side: cancel the pairs that stay inside v's own branch
            Q->cnt -= x;
            Q->sum -= x * d;
            cur -= x * (Q->sum + Q->cnt * d);
            Q->cnt -= x;
            Q->sum -= x * d;
        }
    };

    // Adds / removes the whole A-subtree of v.
    auto set_sub = [&](auto&& self, int v) -> void {
        add(v, 1);
        for (auto [ch, w] : A[v]) self(self, ch);
    };
    auto reset_sub = [&](auto&& self, int v) -> void {
        add(v, mint(-1));
        for (auto [ch, w] : A[v]) self(self, ch);
    };

    // Walks up one heavy path starting at its bottom leaf. At each vertex the
    // set is extended by the vertex and its light subtrees (so it equals the
    // vertex's subtree), the parent edge is accounted for, and at the top of
    // the path the whole subtree is removed again.
    auto climb = [&](int v) {
        while (true) {
            add(v, 1);
            for (int i = 1; i < (int)A[v].size(); i++) set_sub(set_sub, A[v][i].first);
            ans += cur * dist_parent[v];
            const u32 p = heavy_parent[v];
            if (p == NONE) {
                reset_sub(reset_sub, v);
                return;
            }
            v = int(p);
        }
    };

    // Every leaf of A is the bottom of exactly one heavy path.
    for (int i = 0; i < N; i++) {
        if (A[i].empty()) climb(i);
    }

    cout << ans * 2 << "\n";
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int T = 1;
    while (T--) solve();
    return 0;
}