結果
問題 | No.1333 Squared Sum |
ユーザー |
|
提出日時 | 2021-01-08 22:27:45 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 | AC |
実行時間 | 450 ms / 2,000 ms |
コード長 | 7,152 bytes |
コンパイル時間 | 2,217 ms |
コンパイル使用メモリ | 208,116 KB |
最終ジャッジ日時 | 2025-01-17 12:45:18 |
ジャッジサーバーID (参考情報) | judge2 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 44 |
ソースコード
//@formatter:off
#include<bits/stdc++.h>
#define overload4(_1,_2,_3,_4,name,...) name
#define rep1(i,n) for (ll i = 0; i < ll(n); ++i)
#define rep2(i,s,n) for (ll i = ll(s); i < ll(n); ++i)
#define rep3(i,s,n,d) for(ll i = ll(s); i < ll(n); i+=d)
#define rep(...) overload4(__VA_ARGS__,rep3,rep2,rep1)(__VA_ARGS__)
#define rrep(i,n) for (ll i = ll(n)-1; i >= 0; i--)
#define all(a) a.begin(),a.end()
#define rall(a) a.rbegin(),a.rend()
#define pb push_back
#define eb emplace_back
#define vi vector<int>
#define vvi vector<vector<int>>
#define vl vector<ll>
#define vvl vector<vector<ll>>
#define vd vector<double>
#define vvd vector<vector<double>>
#define vs vector<string>
#define vc vector<char>
#define vvc vector<vector<char>>
#define vb vector<bool>
#define vvb vector<vector<bool>>
#define vp vector<P>
#define vvp vector<vector<P>>
#ifdef __LOCAL
#define debug(...) { cout << #__VA_ARGS__; cout << ": "; print(__VA_ARGS__); }
#else
#define debug(...) void(0)
#endif
#define INT(...) int __VA_ARGS__;scan(__VA_ARGS__)
#define LL(...) ll __VA_ARGS__;scan(__VA_ARGS__)
#define STR(...) string __VA_ARGS__;scan(__VA_ARGS__)
#define CHR(...) char __VA_ARGS__;scan(__VA_ARGS__)
#define DBL(...) double __VA_ARGS__;scan(__VA_ARGS__)
#define LD(...) ld __VA_ARGS__;scan(__VA_ARGS__)
using namespace std;
using ll = long long;
using ld = long double;  // required by the LD(...) macro above
using P = pair<int,int>;
using LP = pair<ll,ll>;
template<class S,class T> istream& operator>>(istream &is,pair<S,T> &p) { return is >> p.first >> p.second; }
template<class S,class T> ostream& operator<<(ostream &os,const pair<S,T> &p) { return os<<'{'<<p.first<<","<<p.second<<'}'; }
template<class T> istream& operator>>(istream &is,vector<T> &v) { for(T &t:v){is>>t;} return is; }
template<class T> ostream& operator<<(ostream &os,const vector<T> &v) { os<<'[';rep(i,v.size())os<<v[i]<<(i==int(v.size()-1)?"":","); return os<<']';}
void Yes(bool b) { cout << (b ? "Yes" : "No") << '\n'; }
void YES(bool b) { cout << (b ? "YES" : "NO") << '\n'; }
template<class T> void fin(T a) { cout << a << '\n'; exit(0); }
template<class T> bool chmin(T& a,T b) {if(a > b){a = b; return true;} return false;}
template<class T> bool chmax(T& a,T b) {if(a < b){a = b; return true;} return false;}
void scan(){}
template <class Head, class... Tail> void scan(Head& head, Tail&... tail){ cin >> head; scan(tail...); }
template<class T> void print(const T& t){ cout << t << '\n'; }
template <class Head, class... Tail> void print(const Head& head, const Tail&... tail){ cout<<head<<' '; print(tail...); }
const int inf = 1001001001;
const ll linf = 1001001001001001001;
//@formatter:on

constexpr int mod = 1000000007;
//constexpr int mod = 998244353;

// Integer modulo `mod`; x is always kept normalized to [0, mod).
struct mint {
    ll x;

    constexpr mint(ll x = 0) : x((x % mod + mod) % mod) {}

    constexpr mint operator-() const { return mint(-x); }

    constexpr mint &operator+=(const mint &a) {
        if ((x += a.x) >= mod) x -= mod;
        return *this;
    }

    constexpr mint &operator++() { return *this += mint(1); }

    constexpr mint &operator-=(const mint &a) {
        if ((x += mod - a.x) >= mod) x -= mod;
        return *this;
    }

    constexpr mint &operator--() { return *this -= mint(1); }

    // Both operands are < mod (~1e9), so the product fits in a long long.
    constexpr mint &operator*=(const mint &a) {
        (x *= a.x) %= mod;
        return *this;
    }

    constexpr mint operator+(const mint &a) const {
        mint res(*this);
        return res += a;
    }

    constexpr mint operator-(const mint &a) const {
        mint res(*this);
        return res -= a;
    }

    constexpr mint operator*(const mint &a) const {
        mint res(*this);
        return res *= a;
    }

    // Binary exponentiation: *this raised to t (t >= 0).
    constexpr mint pow(ll t) const {
        mint res = mint(1), a(*this);
        while (t > 0) {
            if (t & 1) res *= a;
            t >>= 1;
            a *= a;
        }
        return res;
    }

    // for prime mod (Fermat's little theorem)
    constexpr mint inv() const { return pow(mod - 2); }

    constexpr mint &operator/=(const mint &a) { return *this *= a.inv(); }

    constexpr mint operator/(const mint &a) const {
        mint res(*this);
        return res /= a;
    }
};

ostream &operator<<(ostream &os, const mint &a) { return os << a.x; }
bool operator==(const mint &a, const mint &b) { return a.x == b.x; }
bool operator!=(const mint &a, const mint &b) { return a.x != b.x; }
bool operator==(const mint &a, const int &b) { return a.x == b; }
bool operator!=(const mint &a, const int &b) { return a.x != b; }

// Generic rerooting DP over a weighted tree with vertices 0..n-1, given as
// adjacency lists of (neighbor, weight) pairs. For every vertex v, get_ans(v)
// is the DP value of the whole tree rooted at v, where
//   value(r) = addRoot( merge over neighbors u of addEdge(value of u's side, w(r,u)) )
// and `identity` must be the neutral element of `merge`.
//
// Both passes walk a BFS order from vertex 0 instead of recursing, so deep
// (path-shaped) trees cannot overflow the call stack. Children are merged in
// adjacency-list order, exactly as a recursive DFS would.
template<typename T, typename MERGE, typename ADDROOT, typename ADDEDGE>
class rerooting {
    int n;
    vvp tree;
    T identity;
    MERGE merge;
    ADDROOT addRoot;
    ADDEDGE addEdge;
    // dp[v][i]: DP value of the component on the far side of edge tree[v][i],
    // rooted at that neighbor.
    vector<vector<T>> dp;
    vector<T> ans;

    void init() {
        if (n <= 0) return;

        // BFS from vertex 0; parent[v] == -1 for the root (and unreached vertices).
        vector<int> parent(n, -1), order;
        order.reserve(n);
        order.push_back(0);
        for (size_t k = 0; k < order.size(); ++k) {
            const int v = order[k];
            for (const P &e : tree[v]) {
                if (e.first == parent[v]) continue;
                parent[e.first] = v;
                order.push_back(e.first);
            }
        }

        // Pass 1 (leaves first): subtree[v] = DP value of v's subtree with the tree rooted at 0.
        vector<T> subtree(n, identity);
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            const int v = *it;
            const int sz = static_cast<int>(tree[v].size());
            dp[v].resize(sz);
            T sum = identity;
            for (int i = 0; i < sz; ++i) {
                const int u = tree[v][i].first;
                if (u == parent[v]) continue;
                dp[v][i] = subtree[u];
                sum = merge(sum, addEdge(dp[v][i], tree[v][i].second));
            }
            subtree[v] = addRoot(sum);
        }

        // Pass 2 (root first): fill in the parent-side value for each vertex, then use
        // prefix/suffix merges to get, for each child, the value of "everything except
        // that child's subtree" in O(deg) per vertex.
        vector<T> fromParent(n, identity);
        vector<T> sumL, sumR;  // reused across vertices to avoid per-vertex reallocation
        for (const int v : order) {
            const int sz = static_cast<int>(tree[v].size());
            for (int i = 0; i < sz; ++i)
                if (tree[v][i].first == parent[v]) dp[v][i] = fromParent[v];

            sumL.assign(sz + 1, identity);
            sumR.assign(sz + 1, identity);
            for (int i = 0; i < sz; ++i)
                sumL[i + 1] = merge(sumL[i], addEdge(dp[v][i], tree[v][i].second));
            for (int i = sz - 1; i >= 0; --i)
                sumR[i] = merge(sumR[i + 1], addEdge(dp[v][i], tree[v][i].second));

            ans[v] = addRoot(sumL[sz]);

            for (int i = 0; i < sz; ++i) {
                const int u = tree[v][i].first;
                if (u == parent[v]) continue;
                fromParent[u] = addRoot(merge(sumL[i], sumR[i + 1]));
            }
        }
    }

public:
    rerooting(int n, vvp tree, T identity, MERGE merge, ADDROOT addRoot, ADDEDGE addEdge)
        : n(n), tree(std::move(tree)), identity(identity), merge(merge),
          addRoot(addRoot), addEdge(addEdge), dp(n), ans(n) {
        init();
    }

    T get_ans(int i) { return ans[i]; }
};

// Aggregate over the vertices of a rooted component:
//   cnt    = number of vertices,
//   sum    = sum of distances from the component's root,
//   sq_sum = sum of squared distances from the component's root.
struct state {
    mint sq_sum, sum;
    int cnt;

    state(mint sq_sum = 0, mint sum = 0, int cnt = 0) : sq_sum(sq_sum), sum(sum), cnt(cnt) {}
};

// Reads a weighted tree (1-indexed edges "u v w") and prints
//   sum over unordered pairs {a, b} of dist(a, b)^2   (mod 1e9+7).
// Rerooting yields, for every vertex r, the sum of squared distances from r to all
// vertices. Summing over r counts each unordered pair twice (self-pairs contribute
// 0), hence the final division by 2.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    INT(n);
    vvp G(n);
    rep(_, n - 1) {
        INT(u, v, w);
        u--;
        v--;
        G[u].eb(v, w);
        G[v].eb(u, w);
    }

    auto merge = [](const state &a, const state &b) {
        return state(a.sq_sum + b.sq_sum, a.sum + b.sum, a.cnt + b.cnt);
    };
    // The root itself is one more vertex at distance 0.
    auto addRoot = [](const state &a) {
        return state(a.sq_sum, a.sum, a.cnt + 1);
    };
    // Every distance d grows by `weight`:
    //   sum (d + w)^2 = sq_sum + 2 w sum + w^2 cnt,   sum (d + w) = sum + w cnt.
    auto addEdge = [](const state &a, int weight) {
        return state(a.sq_sum + a.sum * weight * 2 + mint(weight) * weight * a.cnt,
                     a.sum + mint(weight) * a.cnt,
                     a.cnt);
    };

    rerooting<state, decltype(merge), decltype(addRoot), decltype(addEdge)>
        rt(n, std::move(G), state(0, 0, 0), merge, addRoot, addEdge);

    mint ans = 0;
    rep(i, n) ans += rt.get_ans(i).sq_sum;
    print(ans / 2);
}