結果
問題 | No.2116 Making Forest Hard |
ユーザー |
![]() |
提出日時 | 2022-10-05 09:56:20 |
言語 | C++14 (gcc 13.3.0 + boost 1.87.0) |
結果 | AC |
実行時間 | 1,074 ms / 8,000 ms |
コード長 | 12,188 bytes |
コンパイル時間 | 3,130 ms |
コンパイル使用メモリ | 202,248 KB |
実行使用メモリ | 345,088 KB |
最終ジャッジ日時 | 2024-06-25 01:20:47 |
合計ジャッジ時間 | 44,810 ms |
ジャッジサーバーID (参考情報) | judge2 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 53 |
コンパイルメッセージ
main.cpp:74:16: warning: inline variables are only available with '-std=c++17' or '-std=gnu++17' [-Wc++17-extensions]
   74 | static inline constexpr key_t KEY_MIN = std::numeric_limits<key_t>::min();
      |        ^~~~~~
main.cpp:75:16: warning: inline variables are only available with '-std=c++17' or '-std=gnu++17' [-Wc++17-extensions]
   75 | static inline constexpr key_t KEY_MAX = std::numeric_limits<key_t>::max();
      |        ^~~~~~
ソースコード
#include <bits/stdc++.h>

// Reads one whitespace-separated int from stdin.
// Returns 0 if nothing could be read (previously the value was uninitialized).
int ri() {
	int n = 0;
	if (scanf("%d", &n) != 1) return 0;
	return n;
}

// Integer modulo `mod`. Invariant: the stored residue `x` lies in [0, mod).
template<int mod>
struct ModInt {
	int x;
	ModInt () : x(0) {}
	// Reduces any 64-bit value, including negatives, into [0, mod).
	ModInt (int64_t v) : x(v >= 0 ? v % mod : (mod - -v % mod) % mod) {}
	ModInt &operator += (const ModInt &p) {
		if ((x += p.x) >= mod) x -= mod;
		return *this;
	}
	ModInt &operator -= (const ModInt &p) {
		if ((x += mod - p.x) >= mod) x -= mod;
		return *this;
	}
	ModInt &operator *= (const ModInt &p) {
		x = (int64_t) x * p.x % mod;
		return *this;
	}
	ModInt &operator /= (const ModInt &p) {
		*this *= p.inverse();
		return *this;
	}
	// Exponentiation by squaring: `^` means power here, not xor.
	ModInt &operator ^= (int64_t p) {
		ModInt res = 1;
		for (; p; p >>= 1) {
			if (p & 1) res *= *this;
			*this *= *this;
		}
		return *this = res;
	}
	ModInt operator - () const { return ModInt(-x); }
	ModInt operator + (const ModInt &p) const { return ModInt(*this) += p; }
	ModInt operator - (const ModInt &p) const { return ModInt(*this) -= p; }
	ModInt operator * (const ModInt &p) const { return ModInt(*this) *= p; }
	ModInt operator / (const ModInt &p) const { return ModInt(*this) /= p; }
	ModInt operator ^ (int64_t p) const { return ModInt(*this) ^= p; }
	bool operator == (const ModInt &p) const { return x == p.x; }
	bool operator != (const ModInt &p) const { return x != p.x; }
	explicit operator int() const { return x; }
	// Normalizes through the int64_t constructor so negative or out-of-range
	// ints keep the [0, mod) invariant (the raw store did not).
	ModInt &operator = (const int p) {
		x = ModInt(p).x;
		return *this;
	}
	// Extended Euclid; requires gcd(x, mod) == 1.
	ModInt inverse() const {
		int a = x, b = mod, u = 1, v = 0, t;
		while (b > 0) {
			t = a / b;
			a -= t * b;
			std::swap(a, b);
			u -= t * v;
			std::swap(u, v);
		}
		return ModInt(u);
	}
	friend std::ostream & operator << (std::ostream &stream, const ModInt<mod> &p) {
		return stream << p.x;
	}
	friend std::istream & operator >> (std::istream &stream, ModInt<mod> &a) {
		int64_t x;
		stream >> x;
		a = ModInt<mod>(x);
		return stream;
	}
};
typedef ModInt<998244353> mint;

// Node-pool capacity of each avl_map instantiation (allocated once, on the
// first avl_map construction of that instantiation, and never freed).
constexpr int MAX_NODE = 200000 * 20;

// Augmented AVL tree keyed by key_t. Every node stores a value, the aggregate
// of its subtree, and a pending lazy operation that flush() applies to the node
// and composes into its children.
//
// Requirements visible from the code:
//   - key_t is integer-like (numeric_limits, +/- 1 arithmetic);
//   - agg_t is constructible from value_t and supports operator+;
//   - op_t() is the identity op, op_t * op_t composes (left applied first),
//     and op_t provides apply_val / apply_agg.
//
// Range queries take half-open key ranges [key_l, key_r). The pair
// (KEY_MIN, KEY_MAX) is treated as "whole tree".
// NONE (nodes[0]) is a shared empty sentinel with size/height 0.
template<typename key_t, typename value_t, typename agg_t, typename op_t>
struct avl_map {
	static constexpr key_t KEY_MIN = std::numeric_limits<key_t>::min();
	static constexpr key_t KEY_MAX = std::numeric_limits<key_t>::max();

	struct Node {
		Node *l;
		Node *r;
		int size;
		int height;
		key_t key;
		value_t val;
		agg_t agg;
		op_t op;

		// Recomputes size/height/agg from (flushed) children.
		Node *fetch() {
			if (l != avl_map::NONE) l->flush();
			if (r != avl_map::NONE) r->flush();
			size = 1 + l->size + r->size;
			height = 1 + std::max(l->height, r->height);
			agg = l->agg + ((agg_t) val) + r->agg;
			return this;
		}

		// Applies the pending op to this node and pushes it down to the
		// children. The sentinel is skipped so its op is never mutated
		// (children are never null, so the old `if (l)` always fired).
		Node *flush() {
			if (op != op_t()) {
				op.apply_val(val);
				op.apply_agg(agg);
				if (l != avl_map::NONE) l->op = l->op * op;
				if (r != avl_map::NONE) r->op = r->op * op;
				op = op_t();
			}
			return this;
		}

		Node *rotate_l() {
			Node *new_root = r->flush();
			r = new_root->l;
			new_root->l = this;
			return fetch(), new_root->fetch();
		}

		Node *rotate_r() {
			Node *new_root = l->flush();
			l = new_root->r;
			new_root->r = this;
			return fetch(), new_root->fetch();
		}

		int height_diff() { return l->height - r->height; }

		// Restores the AVL invariant here when child heights differ by 2;
		// otherwise just refreshes the node. Returns the new subtree root.
		Node *balance() {
			int dif = height_diff();
			if (dif == 2) {
				if (l->flush()->height_diff() < 0) l = l->rotate_l();
				return rotate_r();
			} else if (dif == -2) {
				if (r->flush()->height_diff() > 0) r = r->rotate_r();
				return rotate_l();
			} else return fetch();
		}
	};

	static Node *nodes, *removed_tmp, *NONE;
	Node *root;
	static int head;

	avl_map (Node *root) : avl_map() { this->root = root; }
	avl_map () {
		if (!head) {
			NONE = nodes = new Node[MAX_NODE];
			nodes[head++] = {NONE, NONE, 0, 0, key_t(), value_t(), agg_t(), op_t()};
		}
		root = NONE;
	}

	// Inserts (key, val); if key already exists, calls func(node, val) instead.
	template<class T> static Node *insert(Node *node, const key_t &key, const value_t &val, const T &func) {
		if (node == NONE) {
			assert(head < MAX_NODE && "avl_map node pool exhausted");
			return &(nodes[head++] = {NONE, NONE, 1, 1, key, val, val, op_t()});
		}
		node->flush();
		if (key < node->key) node->l = insert(node->l, key, val, func);
		else if (key > node->key) node->r = insert(node->r, key, val, func);
		else func(node, val);
		return node->balance();
	}

	// Detaches the maximum node of a non-empty subtree into removed_tmp and
	// returns the remaining subtree.
	static Node *remove_rightmost(Node *node) {
		node->flush();
		if (node->r != NONE) {
			node->r = remove_rightmost(node->r);
			return node->balance();
		} else return (removed_tmp = node)->l;
	}

	static Node *remove(Node *node, const key_t &key) {
		if (node == NONE) return node;
		node->flush();
		if (key < node->key) node->l = remove(node->l, key);
		else if (key > node->key) node->r = remove(node->r, key);
		else {
			if (node->l == NONE) return node->r;
			// Replace the removed node by its in-order predecessor.
			node->l = remove_rightmost(node->l);
			removed_tmp->l = node->l;
			removed_tmp->r = node->r;
			return removed_tmp->balance();
		}
		return node->balance();
	}

	// Joins l, m, r assuming keys(l) < m->key < keys(r).
	static Node *merge_with_root(Node *l, Node *m, Node *r) {
		int dif = l->height - r->height;
		if (-2 <= dif && dif <= 2) {
			m->l = l;
			m->r = r;
			return m->balance();
		} else if (dif > 0) {
			l->flush();
			l->r = merge_with_root(l->r, m, r);
			return l->balance();
		} else {
			r->flush();
			r->l = merge_with_root(l, m, r->l);
			return r->balance();
		}
	}

	// Concatenates two trees assuming keys(l) < keys(r).
	static Node *merge(Node *l, Node *r) {
		if (l == NONE) return r;
		if (r == NONE) return l;
		l = remove_rightmost(l);
		return merge_with_root(l, removed_tmp, r);
	}

	// Returns {keys < key, keys >= key}.
	static std::pair<Node *, Node *> split(Node *node, const key_t &key) {
		if (node == NONE) return {NONE, NONE};
		node->flush();
		Node *l = node->l;
		Node *r = node->r;
		node->l = node->r = NONE;
		if (key < node->key) {
			auto tmp = split(l, key);
			return {tmp.first, merge_with_root(tmp.second, node, r)};
		} else if (key > node->key) {
			auto tmp = split(r, key);
			return {merge_with_root(l, node, tmp.first), tmp.second};
		} else return {l, merge_with_root(NONE, node->fetch(), r)};
	}

	// Aggregate over keys in [key_l, key_r). Passing KEY_MIN / KEY_MAX for a
	// bound marks the subtree as fully covered on that side.
	static agg_t aggregate(Node *node, const key_t &key_l, const key_t &key_r) {
		agg_t res{};
		if (node == NONE || key_r <= key_l) return res;
		node->flush();
		if (key_l == KEY_MIN && key_r == KEY_MAX) return node->agg;
		if (key_l < node->key) res = res + aggregate(node->l, key_l, key_r > node->key ? KEY_MAX : key_r);
		if (key_l <= node->key && key_r > node->key) res = res + (agg_t) node->val;
		// `key_r - 1 > key` instead of `key_r > key + 1`: key + 1 overflows at
		// KEY_MAX, while key_r > key_l >= KEY_MIN keeps key_r - 1 in range.
		if (key_r - 1 > node->key) res = res + aggregate(node->r, key_l <= node->key ? KEY_MIN : key_l, key_r);
		return res;
	}

	// Applies op to every value with key in [key_l, key_r).
	static void operate(Node *node, const key_t &key_l, const key_t &key_r, const op_t &op) {
		if (node == NONE || key_r <= key_l) return;
		node->flush();
		if (key_l == KEY_MIN && key_r == KEY_MAX) {
			node->op = op;
			node->flush();
			return;
		}
		if (key_l < node->key) operate(node->l, key_l, key_r > node->key ? KEY_MAX : key_r, op);
		if (key_l <= node->key && key_r > node->key) op.apply_val(node->val);
		// Overflow-safe form of `key_r > node->key + 1` (see aggregate).
		if (key_r - 1 > node->key) operate(node->r, key_l <= node->key ? KEY_MIN : key_l, key_r, op);
		node->fetch();
	}

	// In-order (key, value) listing; an empty tree yields nothing
	// (previously the sentinel itself was emitted).
	static void dump(Node *node, std::vector<std::pair<key_t, value_t> > &res) {
		if (node == NONE) return;
		node->flush();
		dump(node->l, res);
		res.push_back({node->key, node->val});
		dump(node->r, res);
	}

	// member functions
	int size() const { return root->size; }
	void assign(const key_t &key, const value_t &val) {
		root = insert(root, key, val, [] (Node *node, const value_t &val) { node->val = val; });
	}
	void add(const key_t &key, const value_t &val) {
		root = insert(root, key, val, [] (Node *node, const value_t &val) { node->val = ((agg_t) node->val) + (agg_t) val; });
	}
	void remove(const key_t &key) { root = remove(root, key); }
	static avl_map merge(const std::vector<avl_map> &a) {
		if (a.empty()) return avl_map();
		Node *res = a[0].root;
		for (int i = 1; i < (int) a.size(); i++) res = merge(res, a[i].root);
		return avl_map(res);
	}
	std::pair<avl_map, avl_map> split(const key_t &key) {
		auto tmp = split(root, key);
		return {avl_map(tmp.first), avl_map(tmp.second)};
	}
	agg_t aggregate(const key_t &l, const key_t &r) { return aggregate(root, l, r); }
	agg_t aggregate_left(const key_t &r) { return aggregate(root, KEY_MIN, r); }
	agg_t aggregate_right(const key_t &l) { return aggregate(root, l, KEY_MAX); }
	void operate(const key_t &l, const key_t &r, const op_t &op) { operate(root, l, r, op); }
	void operate_left(const key_t &r, const op_t &op) { operate(root, KEY_MIN, r, op); }
	void operate_right(const key_t &l, const op_t &op) { operate(root, l, KEY_MAX, op); }
	std::vector<std::pair<key_t, value_t> > dump() const {
		std::vector<std::pair<key_t, value_t> > res;
		dump(root, res);
		return res;
	}
	// Rewinds the node pool (keeps the sentinel); existing trees become invalid.
	static void reset() { if (head) head = 1; }
	void clear() { root = NONE; }
};

#if __cplusplus < 201703L
// Before C++17, constexpr static members that are odr-used (KEY_MIN/KEY_MAX
// bind to `const key_t &` parameters) need a namespace-scope definition.
template<typename key_t, typename value_t, typename agg_t, typename op_t>
constexpr key_t avl_map<key_t, value_t, agg_t, op_t>::KEY_MIN;
template<typename key_t, typename value_t, typename agg_t, typename op_t>
constexpr key_t avl_map<key_t, value_t, agg_t, op_t>::KEY_MAX;
#endif
template<typename key_t, typename value_t, typename agg_t, typename op_t>
typename avl_map<key_t, value_t, agg_t, op_t>::Node *avl_map<key_t, value_t, agg_t, op_t>::nodes;
template<typename key_t, typename value_t, typename agg_t, typename op_t>
typename avl_map<key_t, value_t, agg_t, op_t>::Node *avl_map<key_t, value_t, agg_t, op_t>::removed_tmp;
template<typename key_t, typename value_t, typename agg_t, typename op_t>
typename avl_map<key_t, value_t, agg_t, op_t>::Node *avl_map<key_t, value_t, agg_t, op_t>::NONE;
template<typename key_t, typename value_t, typename agg_t, typename op_t>
int avl_map<key_t, value_t, agg_t, op_t>::head = 0;

// Element-wise sum of two fixed-size arrays.
template<typename T, size_t n> std::array<T, n> array_add(const std::array<T, n> &a, const std::array<T, n> &b) {
	std::array<T, n> res;
	for (size_t i = 0; i < n; i++) res[i] = a[i] + b[i];
	return res;
}

// Four mint components, combined element-wise.
struct value_t {
	std::array<mint, 4> vec{};
	value_t operator + (const value_t &rhs) const { return {array_add(vec, rhs.vec)}; }
};

// Lazy operation: a 2x2 matrix applied as (row vector) * mat to components
// {0, 1} and, independently, to components {2, 3} of a value_t.
// A default-constructed op_t is the identity matrix.
struct op_t {
	std::array<std::array<mint, 2>, 2> mat = {{{1, 0}, {0, 1}}};
	static op_t zero() { return op_t{{{{0, 0}, {0, 0}}}}; }
	template<class T> void apply(T &val) const {
		std::array<mint, 4> tmp{};
		for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) tmp[j + 0] += val.vec[i + 0] * mat[i][j];
		for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) tmp[j + 2] += val.vec[i + 2] * mat[i][j];
		val.vec = tmp;
	}
	void apply_val(value_t &val) const { apply(val); }
	void apply_agg(value_t &sum) const { apply(sum); }
	bool operator != (const op_t &rhs) const { return mat != rhs.mat; }
	// Matrix product (*this) x rhs: applying the result equals applying
	// *this first and then rhs.
	op_t operator * (const op_t &rhs) const {
		op_t res;
		for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++)
			res.mat[i][j] = mat[i][0] * rhs.mat[0][j] + mat[i][1] * rhs.mat[1][j];
		return res;
	}
};

using map = avl_map<int, value_t, value_t, op_t>;

std::vector<std::vector<int> > hen;

// Merges rhs into lhs, smaller into larger (the maps are swapped first so lhs
// is the larger one). For each rhs key it computes new entries from the lhs
// prefix aggregate over keys < key. It then transforms each lhs key range
// between consecutive rhs keys by a matrix built from running sums of rhs
// values, and finally adds the computed entries into lhs.
// Afterwards rhs holds the abandoned smaller tree; its pool nodes are not freed.
void merge(map &lhs, map &rhs) {
	if (lhs.size() < rhs.size()) std::swap(lhs, rhs);
	auto dump = rhs.dump();
	std::vector<std::pair<int, std::array<mint, 4> > > add_list;
	add_list.reserve(dump.size());
	for (int i = (int) dump.size(); i--; ) {
		int key = dump[i].first;
		auto val = dump[i].second.vec;
		auto cur_sum = lhs.aggregate_left(key).vec;
		mint r0 = val[0] * cur_sum[0];
		mint r1 = val[0] * cur_sum[1] + val[1] * cur_sum[0];
		mint r2 = r0 * key;
		mint r3 = r1 * key;
		add_list.push_back({key, {r0, r1, r2, r3}});
	}
	mint sum[4] = { 0 };
	int last_key = map::KEY_MIN;
	for (int i = 0; i < (int) dump.size(); i++) {
		int key = dump[i].first;
		auto val = dump[i].second.vec;
		lhs.operate(last_key, key, op_t{{{{sum[0], sum[1]}, {0, sum[0]}}}});
		last_key = key;
		for (int j = 0; j < 4; j++) sum[j] += val[j];
	}
	lhs.operate(last_key, map::KEY_MAX, op_t{{{{sum[0], sum[1]}, {0, sum[0]}}}});
	for (const auto &entry : add_list) lhs.add(entry.first, value_t{entry.second});
}

std::vector<int> cnt;      // subtree sizes (filled in by dfs)
std::vector<map> dp;
std::vector<int> a;
std::vector<mint> power2;  // power2[k] = 2^k mod p
int n;
mint res = 0;

// Post-order tree DP rooted at i (prev = parent of i, or -1).
// Uses an explicit stack instead of recursion so path-shaped trees cannot
// overflow the call stack. Per child it does exactly what the recursive form
// did, in the same order: finish the child, then add its cnt, assign key -1,
// and merge its map into the parent's.
void dfs(int i, int prev) {
	struct Frame { int v; int parent; size_t next; };
	std::vector<Frame> stack;
	auto enter = [&](int v, int parent) {
		dp[v].assign(a[v], value_t({1, 1, a[v], a[v]}));
		stack.push_back({v, parent, 0});
	};
	enter(i, prev);
	while (!stack.empty()) {
		Frame &top = stack.back();
		if (top.next < hen[top.v].size()) {
			int child = hen[top.v][top.next++];
			// `top` may dangle after enter() grows the stack; it is re-read next iteration.
			if (child != top.parent) enter(child, top.v);
			continue;
		}
		const int v = top.v;
		const int parent = top.parent;
		stack.pop_back();
		res += dp[v].aggregate(map::KEY_MIN, map::KEY_MAX).vec[3] * power2[n - cnt[v] - (parent != -1)];
		if (!stack.empty()) {
			const int u = stack.back().v;
			cnt[u] += cnt[v];
			dp[v].assign(-1, value_t({power2[cnt[v] - 1], 0, 0, -power2[cnt[v] - 1]}));
			merge(dp[u], dp[v]);
		}
	}
}

// Input: n, then a[0..n-1], then n-1 edges (1-indexed endpoints).
// Prints the accumulated answer modulo 998244353.
int main() {
	n = ri();
	a.resize(n);
	for (auto &i : a) i = ri();
	hen.resize(n);
	for (int i = 1; i < n; i++) {
		int x = ri() - 1;
		int y = ri() - 1;
		hen[x].push_back(y);
		hen[y].push_back(x);
	}
	dp.resize(n);
	cnt.resize(n, 1);
	power2.resize(n + 1, 1);
	for (int i = 1; i <= n; i++) power2[i] = power2[i - 1] + power2[i - 1];
	res = 0;
	dfs(0, -1);
	std::cout << res << '\n';
	return 0;
}