結果

問題 No.2116 Making Forest Hard
ユーザー QCFium
提出日時 2022-10-05 09:56:20
言語 C++14
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 1,074 ms / 8,000 ms
コード長 12,188 bytes
コンパイル時間 3,130 ms
コンパイル使用メモリ 202,248 KB
実行使用メモリ 345,088 KB
最終ジャッジ日時 2024-06-25 01:20:47
合計ジャッジ時間 44,810 ms
ジャッジサーバーID
(参考情報)
judge2 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 53
権限があれば一括ダウンロードができます
コンパイルメッセージ
main.cpp:74:16: warning: inline variables are only available with '-std=c++17' or '-std=gnu++17' [-Wc++17-extensions]
   74 |         static inline constexpr key_t KEY_MIN = std::numeric_limits<key_t>::min();
      |                ^~~~~~
main.cpp:75:16: warning: inline variables are only available with '-std=c++17' or '-std=gnu++17' [-Wc++17-extensions]
   75 |         static inline constexpr key_t KEY_MAX = std::numeric_limits<key_t>::max();
      |                ^~~~~~

ソースコード

diff #
プレゼンテーションモードにする

#include <bits/stdc++.h>
// Read one decimal int from stdin via scanf.
// Returns 0 when no integer could be read (EOF / malformed input) instead of
// returning an uninitialized value as before.
int ri() {
	int n = 0;
	if (scanf("%d", &n) != 1) n = 0;
	return n;
}
// Integer modulo the compile-time constant `mod`. The stored value `x` is
// always kept in [0, mod): every way of creating or assigning a ModInt from a
// plain integer goes through normalize(). inverse() uses the extended
// Euclidean algorithm and is only meaningful when gcd(x, mod) == 1.
template<int mod>
struct ModInt{
	int x;
	ModInt () : x(0) {}
	// Implicit on purpose so plain integers mix with ModInt in expressions.
	ModInt (int64_t v) : x(normalize(v)) {}
	ModInt &operator += (const ModInt &p){
		if ((x += p.x) >= mod) x -= mod;
		return *this;
	}
	ModInt &operator -= (const ModInt &p) {
		if ((x += mod - p.x) >= mod) x -= mod;
		return *this;
	}
	ModInt &operator *= (const ModInt &p) {
		x = static_cast<int64_t>(x) * p.x % mod;
		return *this;
	}
	ModInt &operator /= (const ModInt &p) {
		*this *= p.inverse();
		return *this;
	}
	// Exponentiation by squaring: *this = (*this)^p for p >= 0.
	ModInt &operator ^= (int64_t p) {
		ModInt res = 1;
		for (; p; p >>= 1) {
			if (p & 1) res *= *this;
			*this *= *this;
		}
		return *this = res;
	}
	ModInt operator - () const { return ModInt(-x); }
	ModInt operator + (const ModInt &p) const { return ModInt(*this) += p; }
	ModInt operator - (const ModInt &p) const { return ModInt(*this) -= p; }
	ModInt operator * (const ModInt &p) const { return ModInt(*this) *= p; }
	ModInt operator / (const ModInt &p) const { return ModInt(*this) /= p; }
	ModInt operator ^ (int64_t p) const { return ModInt(*this) ^= p; }
	bool operator == (const ModInt &p) const { return x == p.x; }
	bool operator != (const ModInt &p) const { return x != p.x; }
	explicit operator int() const { return x; }
	// Reduce the assigned value (previously stored raw, so x could leave [0, mod)).
	ModInt &operator = (const int p) { x = normalize(p); return *this;}
	// Modular inverse via extended Euclid; requires gcd(x, mod) == 1.
	ModInt inverse() const {
		int a = x, b = mod, u = 1, v = 0, t;
		while (b > 0) {
			t = a / b;
			a -= t * b;
			std::swap(a, b);
			u -= t * v;
			std::swap(u, v);
		}
		return ModInt(u);
	}
	friend std::ostream & operator << (std::ostream &stream, const ModInt<mod> &p) {
		return stream << p.x;
	}
	friend std::istream & operator >> (std::istream &stream, ModInt<mod> &a) {
		int64_t x;
		stream >> x;
		a = ModInt<mod>(x);
		return stream;
	}
	// Map any int64_t into [0, mod). Uses the remainder directly instead of
	// negating the input, so INT64_MIN does not overflow.
	static int normalize(int64_t v) {
		int64_t r = v % mod;
		if (r < 0) r += mod;
		return static_cast<int>(r);
	}
};
using mint = ModInt<998244353>;
// Capacity of avl_map's static node pool (allocated once, never grown).
// A typed constexpr rather than a macro so it is scoped and type-checked.
constexpr int MAX_NODE = 200000 * 20;
// Ordered map key_t -> value_t backed by an AVL tree. Each node also stores
// the aggregate (agg_t) of its subtree and a pending lazy operation (op_t).
// aggregate()/operate() work on half-open key ranges [key_l, key_r).
//
// Requirements visible in the code:
//  - key_t: totally ordered integer-like type with std::numeric_limits;
//    KEY_MIN / KEY_MAX act as "unbounded" range endpoints.
//  - agg_t: default-constructed value is the identity for `+`, and it is
//    constructible from value_t.
//  - op_t: default-constructed value is the identity, supports `!=`,
//    composition `older * newer`, and apply_val / apply_agg.
//
// All nodes come from one static pool of MAX_NODE nodes shared by every map
// with the same template arguments. Nodes are never returned to the pool;
// reset() rewinds the whole pool. Pool index 0 is the sentinel NONE
// (size 0, height 0), which is used instead of null child pointers.
template<typename key_t, typename value_t, typename agg_t, typename op_t> struct avl_map {
	// `static inline` data members are C++17; under -std=c++14 GCC accepts them
	// with a -Wc++17-extensions warning.
	static inline constexpr key_t KEY_MIN = std::numeric_limits<key_t>::min();
	static inline constexpr key_t KEY_MAX = std::numeric_limits<key_t>::max();
	struct Node {
		Node *l;
		Node *r;
		int size;    // number of nodes in this subtree (NONE: 0)
		int height;  // AVL height (NONE: 0, fresh leaf: 1)
		key_t key;
		value_t val;
		agg_t agg;   // subtree aggregate, before this node's pending op is applied
		op_t op;     // pending op for this whole subtree; op_t() means none
		// Recompute size/height/agg from the children, flushing them first so
		// their agg values are up to date.
		Node *fetch() {
			if (l != avl_map::NONE) l->flush();
			if (r != avl_map::NONE) r->flush();
			size = 1 + l->size + r->size;
			height = 1 + std::max(l->height, r->height);
			agg = l->agg + ((agg_t) val) + r->agg;
			return this;
		}
		// Apply the pending op to this node's val/agg and push it down to the
		// children, composed after whatever they already have pending.
		Node *flush() {
			if (op != op_t()) {
				op.apply_val(val);
				op.apply_agg(agg);
				// Children are never null (missing ones point at NONE). Skip the
				// shared sentinel so it never accumulates a pending op of its own.
				if (l != avl_map::NONE) l->op = l->op * op;
				if (r != avl_map::NONE) r->op = r->op * op;
				op = op_t();
			}
			return this;
		}
		// Left rotation around this node; returns the new subtree root.
		Node *rotate_l() {
			Node *new_root = r->flush();
			r = new_root->l;
			new_root->l = this;
			return fetch(), new_root->fetch();
		}
		// Right rotation around this node; returns the new subtree root.
		Node *rotate_r() {
			Node *new_root = l->flush();
			l = new_root->r;
			new_root->r = this;
			return fetch(), new_root->fetch();
		}
		int height_diff() { return l->height - r->height; }
		// Fix a height difference of exactly +/-2 with a single or double
		// rotation; otherwise only recompute this node. Returns the subtree root.
		Node *balance() {
			int dif = height_diff();
			if (dif == 2) {
				if (l->flush()->height_diff() < 0) l = l->rotate_l();
				return rotate_r();
			} else if (dif == -2) {
				if (r->flush()->height_diff() > 0) r = r->rotate_r();
				return rotate_l();
			} else return fetch();
		}
	};
	static Node *nodes, *removed_tmp, *NONE;  // pool base, scratch slot for remove_rightmost, sentinel
	Node *root;
	static int head;  // next free pool index; 0 means the pool is not allocated yet
	avl_map (Node *root) : avl_map() { this->root = root; }
	// The first avl_map ever constructed (per template instantiation) allocates
	// the node pool and places the sentinel NONE at index 0.
	avl_map () {
		if (!head) {
			NONE = nodes = new Node[MAX_NODE];
			nodes[head++] = {NONE, NONE, 0, 0, key_t(), value_t(), agg_t(), op_t()};
		}
		root = NONE;
	}
	// Insert (key, val). If key is already present, func(node, val) is called
	// on the existing node instead. Returns the new subtree root.
	// NOTE(review): there is no check that head < MAX_NODE; exhausting the pool
	// writes past the allocation.
	template<class T> static Node *insert(Node *node, const key_t &key, const value_t &val, const T &func) {
		if (node == NONE) return &(nodes[head++] = {NONE, NONE, 1, 1, key, val, val, op_t()});
		node->flush();
		if (key < node->key) node->l = insert(node->l, key, val, func);
		else if (key > node->key) node->r = insert(node->r, key, val, func);
		else func(node, val);
		return node->balance();
	}
	// Detach the maximum-key node of a non-empty subtree into removed_tmp and
	// return what remains of the subtree.
	static Node *remove_rightmost(Node *node) {
		node->flush();
		if (node->r != NONE) {
			node->r = remove_rightmost(node->r);
			return node->balance();
		} else return (removed_tmp = node)->l;
	}
	// Remove key from the subtree if present (no-op otherwise). The unlinked
	// node is not returned to the pool.
	static Node *remove(Node *node, const key_t &key) {
		if (node == NONE) return node;
		node->flush();
		if (key < node->key) node->l = remove(node->l, key);
		else if (key > node->key) node->r = remove(node->r, key);
		else {
			if (node->l == NONE) return node->r;
			// Replace the removed node with its in-order predecessor.
			node->l = remove_rightmost(node->l);
			removed_tmp->l = node->l;
			removed_tmp->r = node->r;
			return removed_tmp->balance();
		}
		return node->balance();
	}
	// Join l, m, r into one balanced tree, assuming every key in l is below
	// m->key and every key in r is above it. m's child links are overwritten.
	static Node *merge_with_root(Node *l, Node *m, Node *r) {
		int dif = l->height - r->height;
		if (-2 <= dif && dif <= 2) {
			m->l = l;
			m->r = r;
			return m->balance();
		} else if (dif > 0) {
			l->flush();
			l->r = merge_with_root(l->r, m, r);
			return l->balance();
		} else {
			r->flush();
			r->l = merge_with_root(l, m, r->l);
			return r->balance();
		}
	}
	// Concatenate two trees, assuming every key in l is below every key in r.
	static Node *merge(Node *l, Node *r) {
		if (l == NONE) return r;
		if (r == NONE) return l;
		l = remove_rightmost(l);
		return merge_with_root(l, removed_tmp, r);
	}
	// Split into (keys < key, keys >= key).
	static std::pair<Node *, Node *> split(Node *node, const key_t &key) {
		if (node == NONE) return {NONE, NONE};
		node->flush();
		Node *l = node->l;
		Node *r = node->r;
		node->l = node->r = NONE;
		if (key < node->key) {
			auto tmp = split(l, key);
			return {tmp.first, merge_with_root(tmp.second, node, r)};
		} else if (key > node->key) {
			auto tmp = split(r, key);
			return {merge_with_root(l, node, tmp.first), tmp.second};
		} else return {l, merge_with_root(NONE, node->fetch(), r)};
	}
	// Aggregate of the values whose key lies in [key_l, key_r). The exact
	// range [KEY_MIN, KEY_MAX) is treated as "the whole subtree".
	static agg_t aggregate(Node *node, const key_t &key_l, const key_t &key_r) {
		agg_t res{};
		if (node == NONE || key_r <= key_l) return res;
		node->flush();
		if (key_l == KEY_MIN && key_r == KEY_MAX) return node->agg;
		if (key_l < node->key) res = res + aggregate(node->l, key_l, key_r > node->key ? KEY_MAX : key_r);
		if (key_l <= node->key && key_r > node->key) res = res + (agg_t) node->val;
		// Right subtree keys are > node->key, so it matters only if key_r > node->key + 1.
		// Written as key_r - 1 (safe: key_r > key_l >= KEY_MIN here) so that
		// node->key == KEY_MAX cannot overflow.
		if (node->key < key_r - 1) res = res + aggregate(node->r, key_l <= node->key ? KEY_MIN : key_l, key_r);
		return res;
	}
	// Apply op to every value whose key lies in [key_l, key_r). The exact
	// range [KEY_MIN, KEY_MAX) tags the whole subtree lazily.
	static void operate(Node *node, const key_t &key_l, const key_t &key_r, const op_t &op) {
		if (node == NONE || key_r <= key_l) return;
		node->flush();
		if (key_l == KEY_MIN && key_r == KEY_MAX) {
			node->op = op;
			node->flush();
			return;
		}
		if (key_l < node->key) operate(node->l, key_l, key_r > node->key ? KEY_MAX : key_r, op);
		if (key_l <= node->key && key_r > node->key) op.apply_val(node->val);
		// Same overflow-safe form of `key_r > node->key + 1` as in aggregate().
		if (node->key < key_r - 1) operate(node->r, key_l <= node->key ? KEY_MIN : key_l, key_r, op);
		node->fetch();
	}
	// Append the subtree's (key, value) pairs to res in ascending key order.
	static void dump(Node *node, std::vector<std::pair<key_t, value_t> > &res) {
		// An empty (sub)tree contributes nothing; previously the sentinel's
		// default-constructed entry was emitted for an empty map.
		if (node == NONE) return;
		node->flush();
		dump(node->l, res);
		res.push_back({node->key, node->val});
		dump(node->r, res);
	}
	// member functions
	int size() const { return root->size; }
	// Insert key, overwriting the value if key already exists.
	void assign(const key_t &key, const value_t &val) {
		root = insert(root, key, val, [] (Node *node, const value_t &val) { node->val = val; });
	}
	// Insert key, or add val (as agg_t) to the existing value.
	void add(const key_t &key, const value_t &val) {
		root = insert(root, key, val, [] (Node *node, const value_t &val) { node->val = ((agg_t) node->val) + (agg_t) val; });
	}
	void remove(const key_t &key) { root = remove(root, key); }
	// Concatenate maps whose key ranges are ascending and disjoint; an empty
	// list yields an empty map.
	static avl_map merge(const std::vector<avl_map> &a) {
		if (a.empty()) return avl_map();
		Node *res = a[0].root;
		for (int i = 1; i < (int) a.size(); i++) res = merge(res, a[i].root);
		return avl_map(res);
	}
	// Split into (keys < key, keys >= key); this map's root is left stale.
	std::pair<avl_map, avl_map> split(const key_t &key) {
		auto tmp = split(root, key);
		return {avl_map(tmp.first), avl_map(tmp.second)};
	}
	agg_t aggregate(const key_t &l, const key_t &r) { return aggregate(root, l, r); }
	agg_t aggregate_left(const key_t &r) { return aggregate(root, KEY_MIN, r); }
	agg_t aggregate_right(const key_t &l) { return aggregate(root, l, KEY_MAX); }
	void operate(const key_t &l, const key_t &r, const op_t &op) { operate(root, l, r, op); }
	void operate_left(const key_t &r, const op_t &op) { operate(root, KEY_MIN, r, op); }
	void operate_right(const key_t &l, const op_t &op) { operate(root, l, KEY_MAX, op); }
	std::vector<std::pair<key_t, value_t> > dump() const {
		std::vector<std::pair<key_t, value_t> > res;
		dump(root, res);
		return res;
	}
	// Rewind the pool to just after NONE. Existing maps still point into the
	// pool and must not be used afterwards (clear() them).
	static void reset() { if (head) head = 1; }
	void clear() { root = NONE; }
};
// Out-of-class definitions of avl_map's static members (one set per template
// instantiation). The pointers start null and head starts at 0, which tells
// the first avl_map constructor to allocate the node pool.
template<typename K, typename V, typename A, typename O>
typename avl_map<K, V, A, O>::Node *avl_map<K, V, A, O>::nodes = nullptr;
template<typename K, typename V, typename A, typename O>
typename avl_map<K, V, A, O>::Node *avl_map<K, V, A, O>::removed_tmp = nullptr;
template<typename K, typename V, typename A, typename O>
typename avl_map<K, V, A, O>::Node *avl_map<K, V, A, O>::NONE = nullptr;
template<typename K, typename V, typename A, typename O>
int avl_map<K, V, A, O>::head = 0;
// Element-wise sum of two equally sized std::arrays.
template<typename T, size_t n> std::array<T, n> array_add(const std::array<T, n> &a, const std::array<T, n> &b) {
	std::array<T, n> res{};
	// Unsigned index matches the size_t bound (no signed/unsigned comparison).
	for (size_t i = 0; i < n; i++) res[i] = a[i] + b[i];
	return res;
}
// Four mints used as both the value and the aggregate type of the DP map.
// op_t treats vec as two pairs, (vec[0], vec[1]) and (vec[2], vec[3]).
struct value_t {
	std::array<mint, 4> vec{};
	// Component-wise sum. Marked const so it also works on const operands.
	value_t operator + (const value_t &rhs) const {
		value_t res;
		for (int i = 0; i < 4; i++) res.vec[i] = vec[i] + rhs.vec[i];
		return res;
	}
};
// Lazy update for the DP map: a 2x2 matrix M. Applying it multiplies each of
// the two pairs in value_t::vec, taken as a row vector, on the right by M:
//   (vec[0], vec[1]) <- (vec[0], vec[1]) * M
//   (vec[2], vec[3]) <- (vec[2], vec[3]) * M
// The default-constructed op is the identity matrix.
struct op_t {
	std::array<std::array<mint, 2>, 2> mat = {{{1, 0}, {0, 1}}};
	// The all-zero matrix.
	static op_t zero() { return op_t{{{{0, 0}, {0, 0}}}}; }
	template<class T> void apply(T &val) const {
		std::array<mint, 4> out{};
		for (int half = 0; half < 4; half += 2) {
			for (int col = 0; col < 2; col++) {
				out[half + col] = val.vec[half] * mat[0][col] + val.vec[half + 1] * mat[1][col];
			}
		}
		val.vec = out;
	}
	void apply_val(value_t &val) const { apply(val); }
	void apply_agg(value_t &sum) const { apply(sum); }
	bool operator != (const op_t &rhs) const { return !(mat == rhs.mat); }
	// Matrix product this * rhs: applying the result equals applying this, then rhs.
	op_t operator * (const op_t &rhs) const {
		op_t prod;
		for (int row = 0; row < 2; row++) {
			const auto &lrow = mat[row];
			for (int col = 0; col < 2; col++) {
				prod.mat[row][col] = lrow[0] * rhs.mat[0][col] + lrow[1] * rhs.mat[1][col];
			}
		}
		return prod;
	}
};
// Per-vertex DP table: int key -> value_t, summed with value_t::operator+ and
// lazily updated by op_t matrices.
using map = avl_map<int, /* value */ value_t, /* aggregate */ value_t, op_t>;
// Adjacency lists of the input tree (0-indexed vertices).
std::vector<std::vector<int>> hen{};
// Merge the DP table rhs into lhs, small-to-large: if rhs is larger the two
// maps are swapped first, so the smaller one is always the one enumerated.
// On return lhs holds the merged table. The caller (dfs) does not use rhs
// afterwards.
//
// Each value's vec is handled as two pairs, (vec[0], vec[1]) and
// (vec[2], vec[3]). Pairs are combined with
//   (p0, p1) * (q0, q1) = (p0*q0, p0*q1 + p1*q0),
// i.e. multiplication of p0 + p1*e and q0 + q1*e with e^2 = 0. Both the
// explicit r0/r1 below and the op_t matrix [[s0, s1], [0, s0]] compute this.
void merge(map &lhs, map &rhs) {
if (lhs.size() < rhs.size()) std::swap(lhs, rhs);
std::vector<std::pair<int, std::array<mint, 4> > > add_list;
auto dump = rhs.dump();
// Step 1, before lhs is modified: for each rhs entry (key, val), multiply
// val's first pair by the sum of lhs entries with key strictly below key
// (aggregate_left uses the half-open range [KEY_MIN, key)). The second pair
// of the new entry is the first pair scaled by key. The results are added to
// lhs only at the end.
for (int i = dump.size(); i--; ) {
int key = dump[i].first;
auto val = dump[i].second.vec;
auto cur_sum = lhs.aggregate_left(key).vec;
mint r0 = val[0] * cur_sum[0];
mint r1 = val[0] * cur_sum[1] + val[1] * cur_sum[0];
mint r2 = r0 * key;
mint r3 = r1 * key;
add_list.push_back({key, {r0, r1, r2, r3}});
}
// Step 2: walk the rhs keys in ascending order. Every lhs entry whose key
// lies in [previous rhs key, current rhs key) is multiplied by the running
// prefix sum of the rhs entries already passed, i.e. those with key <= the
// lhs key. lhs entries below the smallest rhs key get the all-zero matrix.
// NOTE(review): sum[2] and sum[3] are accumulated but never read.
mint sum[4] = { 0 };
int last_key = map::KEY_MIN;
for (int i = 0; i < (int) dump.size(); i++) {
int key = dump[i].first;
auto val = dump[i].second.vec;
lhs.operate(last_key, key, op_t{{{{sum[0], sum[1]}, {0, sum[0]}}}});
last_key = key;
for (int j = 0; j < 4; j++) sum[j] += val[j];
}
// The tail range [largest rhs key, KEY_MAX) uses the full rhs sum.
lhs.operate(last_key, map::KEY_MAX, op_t{{{{sum[0], sum[1]}, {0, sum[0]}}}});
// Finally, insert the step-1 products, adding into any existing lhs entry
// with the same key.
for (auto i : add_list) lhs.add(i.first, value_t{i.second});
}
// Global state shared by dfs() and main().
std::vector<int> cnt;      // cnt[v]: number of vertices in v's subtree (tree rooted at 0)
std::vector<map> dp;       // dp[v]: DP table built for v's subtree
std::vector<int> a;        // a[v]: input value of vertex v
std::vector<mint> power2;  // power2[k] = 2^k
int n;                     // number of vertices
mint res{0};               // accumulated answer
void dfs(int i, int prev) {
dp[i].assign(a[i], value_t({1, 1, a[i], a[i]}));
for (auto j : hen[i]) if (j != prev) {
dfs(j, i);
cnt[i] += cnt[j];
dp[j].assign(-1, value_t({power2[cnt[j] - 1], 0, 0, -power2[cnt[j] - 1]}));
merge(dp[i], dp[j]);
}
res += dp[i].aggregate(map::KEY_MIN, map::KEY_MAX).vec[3] * power2[n - cnt[i] - (prev != -1)];
}
// Input: n, then a[0..n-1], then n - 1 edges given as 1-indexed vertex pairs.
// Output: the accumulated answer res.
int main() {
	n = ri();
	a.resize(n);
	for (int &value : a) value = ri();
	hen.resize(n);
	for (int edge = 1; edge < n; edge++) {
		const int u = ri() - 1;
		const int v = ri() - 1;
		hen[u].push_back(v);
		hen[v].push_back(u);
	}
	dp.resize(n);
	cnt.resize(n, 1);
	// power2[k] = 2^k for k in [0, n].
	power2.resize(n + 1, 1);
	for (int k = 1; k <= n; k++) power2[k] = power2[k - 1] * 2;
	res = 0;
	dfs(0, -1);
	std::cout << res << std::endl;
	return 0;
}
הההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההה
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
0