結果
問題 | No.2115 Making Forest Easy |
ユーザー | QCFium |
提出日時 | 2022-08-22 21:16:50 |
言語 | C++17 (gcc 12.3.0 + boost 1.83.0) |
結果 | AC |
実行時間 | 167 ms / 2,000 ms |
コード長 | 12,188 bytes |
コンパイル時間 | 2,707 ms |
コンパイル使用メモリ | 237,728 KB |
実行使用メモリ | 317,580 KB |
最終ジャッジ日時 | 2024-06-22 20:44:16 |
合計ジャッジ時間 | 11,842 ms |
ジャッジサーバーID (参考情報) | judge4 / judge1 |
(要ログイン)
テストケース
テストケース表示
入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 146 ms | 316,036 KB |
testcase_01 | AC | 167 ms | 316,112 KB |
testcase_02 | AC | 161 ms | 317,580 KB |
testcase_03 | AC | 155 ms | 316,428 KB |
testcase_04 | AC | 159 ms | 317,104 KB |
testcase_05 | AC | 154 ms | 316,984 KB |
testcase_06 | AC | 157 ms | 317,536 KB |
testcase_07 | AC | 152 ms | 316,484 KB |
testcase_08 | AC | 155 ms | 317,576 KB |
testcase_09 | AC | 155 ms | 316,540 KB |
testcase_10 | AC | 158 ms | 317,544 KB |
testcase_11 | AC | 153 ms | 316,484 KB |
testcase_12 | AC | 156 ms | 316,556 KB |
testcase_13 | AC | 153 ms | 316,428 KB |
testcase_14 | AC | 157 ms | 316,500 KB |
testcase_15 | AC | 154 ms | 316,516 KB |
testcase_16 | AC | 152 ms | 316,500 KB |
testcase_17 | AC | 143 ms | 316,160 KB |
testcase_18 | AC | 142 ms | 316,164 KB |
testcase_19 | AC | 147 ms | 316,816 KB |
testcase_20 | AC | 152 ms | 316,468 KB |
testcase_21 | AC | 140 ms | 316,136 KB |
testcase_22 | AC | 141 ms | 316,044 KB |
testcase_23 | AC | 151 ms | 317,040 KB |
testcase_24 | AC | 160 ms | 316,448 KB |
testcase_25 | AC | 146 ms | 316,424 KB |
testcase_26 | AC | 157 ms | 316,496 KB |
testcase_27 | AC | 147 ms | 316,220 KB |
testcase_28 | AC | 152 ms | 316,468 KB |
testcase_29 | AC | 140 ms | 316,280 KB |
testcase_30 | AC | 155 ms | 316,472 KB |
testcase_31 | AC | 158 ms | 317,244 KB |
testcase_32 | AC | 156 ms | 317,516 KB |
testcase_33 | AC | 156 ms | 317,576 KB |
testcase_34 | AC | 142 ms | 316,240 KB |
testcase_35 | AC | 143 ms | 316,412 KB |
testcase_36 | AC | 157 ms | 317,512 KB |
testcase_37 | AC | 147 ms | 316,148 KB |
testcase_38 | AC | 142 ms | 316,140 KB |
testcase_39 | AC | 145 ms | 316,656 KB |
testcase_40 | AC | 153 ms | 316,468 KB |
testcase_41 | AC | 153 ms | 316,400 KB |
testcase_42 | AC | 155 ms | 316,432 KB |
testcase_43 | AC | 145 ms | 316,112 KB |
testcase_44 | AC | 150 ms | 316,304 KB |
testcase_45 | AC | 154 ms | 316,408 KB |
testcase_46 | AC | 153 ms | 316,496 KB |
testcase_47 | AC | 143 ms | 316,276 KB |
testcase_48 | AC | 157 ms | 317,180 KB |
testcase_49 | AC | 155 ms | 317,256 KB |
testcase_50 | AC | 147 ms | 316,344 KB |
testcase_51 | AC | 147 ms | 316,288 KB |
ソースコード
#include <bits/stdc++.h>

// Reads one int from stdin. Returns 0 when no integer could be parsed, so a
// truncated input never yields an uninitialized value.
int ri() {
	int n = 0;
	if (std::scanf("%d", &n) != 1) return 0;
	return n;
}

// Integer modulo the compile-time constant `mod`; the stored value x is kept in [0, mod).
template<int mod> struct ModInt {
	int x;
	ModInt () : x(0) {}
	// Reduces any 64-bit value (negatives included) into [0, mod).
	ModInt (int64_t v) : x(v >= 0 ? v % mod : (mod - -v % mod) % mod) {}
	ModInt &operator += (const ModInt &p) { if ((x += p.x) >= mod) x -= mod; return *this; }
	ModInt &operator -= (const ModInt &p) { if ((x += mod - p.x) >= mod) x -= mod; return *this; }
	ModInt &operator *= (const ModInt &p) { x = static_cast<int64_t>(x) * p.x % mod; return *this; }
	ModInt &operator /= (const ModInt &p) { *this *= p.inverse(); return *this; }
	// Exponentiation by squaring; p is expected to be non-negative.
	ModInt &operator ^= (int64_t p) {
		ModInt res = 1;
		for (; p; p >>= 1) {
			if (p & 1) res *= *this;
			*this *= *this;
		}
		return *this = res;
	}
	ModInt operator - () const { return ModInt(-x); }
	ModInt operator + (const ModInt &p) const { return ModInt(*this) += p; }
	ModInt operator - (const ModInt &p) const { return ModInt(*this) -= p; }
	ModInt operator * (const ModInt &p) const { return ModInt(*this) *= p; }
	ModInt operator / (const ModInt &p) const { return ModInt(*this) /= p; }
	ModInt operator ^ (int64_t p) const { return ModInt(*this) ^= p; }
	bool operator == (const ModInt &p) const { return x == p.x; }
	bool operator != (const ModInt &p) const { return x != p.x; }
	explicit operator int() const { return x; }
	// Reduce exactly like the converting constructor so x stays in [0, mod)
	// even for negative or out-of-range ints.
	ModInt &operator = (const int p) { x = ModInt(p).x; return *this; }
	// Modular inverse via the extended Euclidean algorithm (needs gcd(x, mod) == 1).
	ModInt inverse() const {
		int a = x, b = mod, u = 1, v = 0, t;
		while (b > 0) {
			t = a / b;
			a -= t * b; std::swap(a, b);
			u -= t * v; std::swap(u, v);
		}
		return ModInt(u);
	}
	friend std::ostream & operator << (std::ostream &stream, const ModInt<mod> &p) { return stream << p.x; }
	friend std::istream & operator >> (std::istream &stream, ModInt<mod> &a) {
		int64_t x;
		stream >> x;
		a = ModInt<mod>(x);
		return stream;
	}
};
typedef ModInt<998244353> mint;

// Capacity of the shared node pool used by every avl_map instantiation.
// Nodes are never recycled; insert() aborts once the pool is exhausted.
#define MAX_NODE (200000 * 20)

// Ordered key -> value map on an AVL tree with subtree aggregates (agg_t) and
// lazily propagated range operations (op_t). All maps of one instantiation
// share a static node pool; nodes[0] is the empty sentinel NONE.
// op_t must provide apply_val, apply_agg, operator*, operator!= and a default
// (identity) value; value_t must be convertible to agg_t, and agg_t must
// support operator+.
template<typename key_t, typename value_t, typename agg_t, typename op_t>
struct avl_map {
	static inline constexpr key_t KEY_MIN = std::numeric_limits<key_t>::min();
	static inline constexpr key_t KEY_MAX = std::numeric_limits<key_t>::max();
	struct Node {
		Node *l;
		Node *r;
		int size;
		int height;
		key_t key;
		value_t val;
		agg_t agg;  // aggregate of the whole subtree
		op_t op;    // pending op, applied to val/agg and pushed to the children by flush()
		// Recomputes size, height and agg from the (flushed) children.
		Node *fetch() {
			if (l != avl_map::NONE) l->flush();
			if (r != avl_map::NONE) r->flush();
			size = 1 + l->size + r->size;
			height = 1 + std::max(l->height, r->height);
			agg = l->agg + static_cast<agg_t>(val) + r->agg;
			return this;
		}
		// Applies the pending op locally and composes it into the children's ops.
		// Children are never null (empty links point to NONE), so compare against
		// NONE to avoid composing ops into the shared sentinel.
		Node *flush() {
			if (op != op_t()) {
				op.apply_val(val);
				op.apply_agg(agg);
				if (l != avl_map::NONE) l->op = l->op * op;
				if (r != avl_map::NONE) r->op = r->op * op;
				op = op_t();
			}
			return this;
		}
		Node *rotate_l() {
			Node *new_root = r->flush();
			r = new_root->l;
			new_root->l = this;
			fetch();
			return new_root->fetch();
		}
		Node *rotate_r() {
			Node *new_root = l->flush();
			l = new_root->r;
			new_root->r = this;
			fetch();
			return new_root->fetch();
		}
		int height_diff() { return l->height - r->height; }
		// Restores the AVL invariant at this node (children differ by at most 2).
		Node *balance() {
			int dif = height_diff();
			if (dif == 2) {
				if (l->flush()->height_diff() < 0) l = l->rotate_l();
				return rotate_r();
			}
			if (dif == -2) {
				if (r->flush()->height_diff() > 0) r = r->rotate_r();
				return rotate_l();
			}
			return fetch();
		}
	};
	static Node *nodes, *removed_tmp, *NONE;
	Node *root;
	static int head;  // index of the next free pool slot
	avl_map (Node *root) : avl_map() { this->root = root; }
	// The first map ever constructed allocates the pool and the NONE sentinel.
	avl_map () {
		if (!head) {
			NONE = nodes = new Node[MAX_NODE];
			nodes[head++] = {NONE, NONE, 0, 0, key_t(), value_t(), agg_t(), op_t()};
		}
		root = NONE;
	}
	// Inserts (key, val), or calls func(node, val) if key is already present.
	template<class T> static Node *insert(Node *node, const key_t &key, const value_t &val, const T &func) {
		if (node == NONE) {
			// Fail loudly instead of writing past the end of the fixed-size pool.
			if (head >= MAX_NODE) {
				std::fputs("avl_map: node pool exhausted\n", stderr);
				std::abort();
			}
			return &(nodes[head++] = {NONE, NONE, 1, 1, key, val, val, op_t()});
		}
		node->flush();
		if (key < node->key) node->l = insert(node->l, key, val, func);
		else if (key > node->key) node->r = insert(node->r, key, val, func);
		else func(node, val);
		return node->balance();
	}
	// Detaches the rightmost node of the subtree into removed_tmp.
	static Node *remove_rightmost(Node *node) {
		node->flush();
		if (node->r != NONE) {
			node->r = remove_rightmost(node->r);
			return node->balance();
		}
		return (removed_tmp = node)->l;
	}
	static Node *remove(Node *node, const key_t &key) {
		if (node == NONE) return node;
		node->flush();
		if (key < node->key) node->l = remove(node->l, key);
		else if (key > node->key) node->r = remove(node->r, key);
		else {
			if (node->l == NONE) return node->r;
			node->l = remove_rightmost(node->l);
			removed_tmp->l = node->l;
			removed_tmp->r = node->r;
			return removed_tmp->balance();
		}
		return node->balance();
	}
	// Joins l, m, r (all keys of l < m->key < all keys of r) into one balanced tree.
	static Node *merge_with_root(Node *l, Node *m, Node *r) {
		int dif = l->height - r->height;
		if (-2 <= dif && dif <= 2) {
			m->l = l;
			m->r = r;
			return m->balance();
		}
		if (dif > 0) {
			l->flush();
			l->r = merge_with_root(l->r, m, r);
			return l->balance();
		}
		r->flush();
		r->l = merge_with_root(l, m, r->l);
		return r->balance();
	}
	static Node *merge(Node *l, Node *r) {
		if (l == NONE) return r;
		if (r == NONE) return l;
		l = remove_rightmost(l);
		return merge_with_root(l, removed_tmp, r);
	}
	// Splits into keys < key and keys >= key.
	static std::pair<Node *, Node *> split(Node *node, const key_t &key) {
		if (node == NONE) return {NONE, NONE};
		node->flush();
		Node *l = node->l;
		Node *r = node->r;
		node->l = node->r = NONE;
		if (key < node->key) {
			auto tmp = split(l, key);
			return {tmp.first, merge_with_root(tmp.second, node, r)};
		}
		if (key > node->key) {
			auto tmp = split(r, key);
			return {merge_with_root(l, node, tmp.first), tmp.second};
		}
		return {l, merge_with_root(NONE, node->fetch(), r)};
	}
	// Aggregate over keys in [key_l, key_r). KEY_MIN / KEY_MAX mark unbounded sides.
	static agg_t aggregate(Node *node, const key_t &key_l, const key_t &key_r) {
		agg_t res{};
		if (node == NONE || key_r <= key_l) return res;
		node->flush();
		if (key_l == KEY_MIN && key_r == KEY_MAX) return node->agg;
		if (key_l < node->key) res = res + aggregate(node->l, key_l, key_r > node->key ? KEY_MAX : key_r);
		if (key_l <= node->key && key_r > node->key) res = res + static_cast<agg_t>(node->val);
		// Equivalent to key_r > node->key + 1 without overflowing when key == KEY_MAX;
		// key_r - 1 is safe because key_r > key_l >= KEY_MIN here.
		if (node->key < key_r - 1) res = res + aggregate(node->r, key_l <= node->key ? KEY_MIN : key_l, key_r);
		return res;
	}
	// Applies op to every value with key in [key_l, key_r).
	static void operate(Node *node, const key_t &key_l, const key_t &key_r, const op_t &op) {
		if (node == NONE || key_r <= key_l) return;
		node->flush();
		if (key_l == KEY_MIN && key_r == KEY_MAX) {
			node->op = op;
			node->flush();
			return;
		}
		if (key_l < node->key) operate(node->l, key_l, key_r > node->key ? KEY_MAX : key_r, op);
		if (key_l <= node->key && key_r > node->key) op.apply_val(node->val);
		// Same overflow-safe bound as in aggregate().
		if (node->key < key_r - 1) operate(node->r, key_l <= node->key ? KEY_MIN : key_l, key_r, op);
		node->fetch();
	}
	// In-order traversal: appends (key, value) pairs in ascending key order.
	static void dump(Node *node, std::vector<std::pair<key_t, value_t> > &res) {
		node->flush();
		if (node->l != NONE) dump(node->l, res);
		res.push_back({node->key, node->val});
		if (node->r != NONE) dump(node->r, res);
	}

	// member functions
	int size() const { return root->size; }
	void assign(const key_t &key, const value_t &val) {
		root = insert(root, key, val, [] (Node *node, const value_t &val) { node->val = val; });
	}
	void add(const key_t &key, const value_t &val) {
		root = insert(root, key, val, [] (Node *node, const value_t &val) {
			node->val = static_cast<agg_t>(node->val) + static_cast<agg_t>(val);
		});
	}
	void remove(const key_t &key) { root = remove(root, key); }
	static avl_map merge(const std::vector<avl_map> &a) {
		Node *res = a[0].root;
		for (int i = 1; i < (int) a.size(); i++) res = merge(res, a[i].root);
		return avl_map(res);
	}
	std::pair<avl_map, avl_map> split(const key_t &key) {
		auto tmp = split(root, key);
		return {avl_map(tmp.first), avl_map(tmp.second)};
	}
	agg_t aggregate(const key_t &l, const key_t &r) { return aggregate(root, l, r); }
	agg_t aggregate_left(const key_t &r) { return aggregate(root, KEY_MIN, r); }
	agg_t aggregate_right(const key_t &l) { return aggregate(root, l, KEY_MAX); }
	void operate(const key_t &l, const key_t &r, const op_t &op) { operate(root, l, r, op); }
	void operate_left(const key_t &r, const op_t &op) { operate(root, KEY_MIN, r, op); }
	void operate_right(const key_t &l, const op_t &op) { operate(root, l, KEY_MAX, op); }
	std::vector<std::pair<key_t, value_t> > dump() const {
		std::vector<std::pair<key_t, value_t> > res;
		dump(root, res);
		return res;
	}
	static void reset() { if (head) head = 1; }
	void clear() { root = NONE; }
};
template<typename key_t, typename value_t, typename agg_t, typename op_t>
typename avl_map<key_t, value_t, agg_t, op_t>::Node *avl_map<key_t, value_t, agg_t, op_t>::nodes;
template<typename key_t, typename value_t, typename agg_t, typename op_t>
typename avl_map<key_t, value_t, agg_t, op_t>::Node *avl_map<key_t, value_t, agg_t, op_t>::removed_tmp;
template<typename key_t, typename value_t, typename agg_t, typename op_t>
typename avl_map<key_t, value_t, agg_t, op_t>::Node *avl_map<key_t, value_t, agg_t, op_t>::NONE;
template<typename key_t, typename value_t, typename agg_t, typename op_t>
int avl_map<key_t, value_t, agg_t, op_t>::head = 0;

// Element-wise sum of two fixed-size arrays.
template<typename T, size_t n> std::array<T, n> array_add(const std::array<T, n> &a, const std::array<T, n> &b) {
	std::array<T, n> res;
	for (size_t i = 0; i < n; i++) res[i] = a[i] + b[i];
	return res;
}

// Four mod-p components. The pairs (vec[0], vec[1]) and (vec[2], vec[3]) are
// transformed independently by op_t, and elsewhere vec[2..3] are built as
// key * vec[0..1] (see merge/dfs below).
struct value_t {
	std::array<mint, 4> vec{};
	value_t operator + (const value_t &rhs) const {
		std::array<mint, 4> res;
		for (int i = 0; i < 4; i++) res[i] = vec[i] + rhs.vec[i];
		return {res};
	}
};

// 2x2 matrix acting on each pair of a value_t as a row vector (v * mat).
// The matrices built in merge() have the form {{s0, s1}, {0, s0}}, which maps
// (v0, v1) -> (v0*s0, v0*s1 + v1*s0), i.e. dual-number multiplication.
struct op_t {
	std::array<std::array<mint, 2>, 2> mat = {{{1, 0}, {0, 1}}};
	static op_t zero() { return op_t{{{{0, 0}, {0, 0}}}}; }
	template<class T> void apply(T &val) const {
		std::array<mint, 4> tmp{};
		for (int i = 0; i < 2; i++)
			for (int j = 0; j < 2; j++) tmp[j + 0] += val.vec[i + 0] * mat[i][j];
		for (int i = 0; i < 2; i++)
			for (int j = 0; j < 2; j++) tmp[j + 2] += val.vec[i + 2] * mat[i][j];
		val.vec = tmp;
	}
	void apply_val(value_t &val) const { apply(val); }
	void apply_agg(value_t &sum) const { apply(sum); }
	bool operator != (const op_t &rhs) const { return mat != rhs.mat; }
	// Composition: applying *this and then rhs equals applying (*this * rhs).
	op_t operator * (const op_t &rhs) const {
		op_t res;
		for (int i = 0; i < 2; i++)
			for (int j = 0; j < 2; j++)
				res.mat[i][j] = mat[i][0] * rhs.mat[0][j] + mat[i][1] * rhs.mat[1][j];
		return res;
	}
};
using map = avl_map<int, value_t, value_t, op_t>;

std::vector<std::vector<int> > hen;  // adjacency lists of the input tree

// Combines rhs into lhs (small-to-large: the larger map ends up in lhs; the
// other map is left holding abandoned nodes and must not be used afterwards).
// Each lhs entry with key k is multiplied (as a dual number) by the sum of rhs
// entries with key <= k. Each rhs entry with key k' contributes its value times
// the sum of the pre-merge lhs entries with key < k', stored at key k' with
// components 2..3 set to key * components 0..1.
void merge(map &lhs, map &rhs) {
	if (lhs.size() < rhs.size()) std::swap(lhs, rhs);
	std::vector<std::pair<int, std::array<mint, 4> > > add_list;
	auto dump = rhs.dump();  // ascending key order
	add_list.reserve(dump.size());
	// Contributions of rhs entries, computed against lhs before it is modified.
	for (int i = dump.size(); i--; ) {
		int key = dump[i].first;
		auto val = dump[i].second.vec;
		auto cur_sum = lhs.aggregate_left(key).vec;
		mint r0 = val[0] * cur_sum[0];
		mint r1 = val[0] * cur_sum[1] + val[1] * cur_sum[0];
		mint r2 = r0 * key;
		mint r3 = r1 * key;
		add_list.push_back({key, {r0, r1, r2, r3}});
	}
	// Scale each lhs key range [previous rhs key, next rhs key) by the running
	// prefix sum of rhs (only components 0 and 1 enter the matrix).
	mint sum[2] = {0, 0};
	int last_key = map::KEY_MIN;
	for (int i = 0; i < (int) dump.size(); i++) {
		int key = dump[i].first;
		auto val = dump[i].second.vec;
		lhs.operate(last_key, key, op_t{{{{sum[0], sum[1]}, {0, sum[0]}}}});
		last_key = key;
		for (int j = 0; j < 2; j++) sum[j] += val[j];
	}
	lhs.operate(last_key, map::KEY_MAX, op_t{{{{sum[0], sum[1]}, {0, sum[0]}}}});
	for (auto i : add_list) lhs.add(i.first, value_t{i.second});
}

std::vector<int> cnt;        // cnt[v]: size of v's subtree
std::vector<map> dp;         // per-vertex DP map, keyed by a[] values (and -1)
std::vector<int> a;          // vertex values
std::vector<mint> power2;    // power2[k] = 2^k mod p
int n;
mint res = 0;

// Builds dp[v] for every vertex v in the subtree rooted at i (whose parent is
// prev, or -1 for the global root), folds each child into its parent with
// merge(), and adds each vertex's term to `res`. The term is
// power2[n - cnt[v] - (v has a parent)], i.e. 2^(number of tree edges outside
// v's subtree other than the edge to its parent).
// Iterative rather than recursive so path-shaped trees cannot overflow the call stack.
void dfs(int i, int prev) {
	// Preorder list of (vertex, parent): every vertex appears after its parent.
	std::vector<std::pair<int, int> > order;
	std::vector<std::pair<int, int> > stack{{i, prev}};
	while (!stack.empty()) {
		auto [v, p] = stack.back();
		stack.pop_back();
		order.push_back({v, p});
		for (auto u : hen[v]) if (u != p) stack.push_back({u, v});
	}
	for (auto [v, p] : order) dp[v].assign(a[v], value_t({1, 1, a[v], a[v]}));
	// Reverse preorder reaches every vertex after all of its descendants, so
	// dp[v] and cnt[v] are final when v is processed.
	for (int k = (int) order.size(); k--; ) {
		auto [v, p] = order[k];
		res += dp[v].aggregate(map::KEY_MIN, map::KEY_MAX).vec[3] * power2[n - cnt[v] - (p != -1)];
		if (k == 0) break;  // the subtree root is folded into its parent by the caller, not here
		cnt[p] += cnt[v];
		// Key -1 sorts below every a[] value (assuming a[] >= 0 — TODO confirm
		// against the constraints); it holds the child's extra entry before merging.
		dp[v].assign(-1, value_t({power2[cnt[v] - 1], 0, 0, -power2[cnt[v] - 1]}));
		merge(dp[p], dp[v]);
	}
}

// Input: n, then a[0..n-1], then n-1 edges as 1-based vertex pairs.
int main() {
	n = ri();
	if (n <= 0) {
		// Nothing to process; avoids indexing dp[0] of an empty vector.
		std::cout << res << std::endl;
		return 0;
	}
	a.resize(n);
	for (auto &i : a) i = ri();
	hen.resize(n);
	for (int i = 1; i < n; i++) {
		int x = ri() - 1;
		int y = ri() - 1;
		hen[x].push_back(y);
		hen[y].push_back(x);
	}
	dp.resize(n);
	cnt.resize(n, 1);
	power2.resize(n + 1, 1);
	for (int i = 1; i <= n; i++) power2[i] = power2[i - 1] + power2[i - 1];
	res = 0;
	dfs(0, -1);
	std::cout << res << std::endl;
	return 0;
}