結果

問題 No.2115 Making Forest Easy
ユーザー QCFium
提出日時 2022-08-22 21:16:50
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 167 ms / 2,000 ms
コード長 12,188 bytes
コンパイル時間 2,707 ms
コンパイル使用メモリ 237,728 KB
実行使用メモリ 317,580 KB
最終ジャッジ日時 2024-06-22 20:44:16
合計ジャッジ時間 11,842 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 146 ms
316,036 KB
testcase_01 AC 167 ms
316,112 KB
testcase_02 AC 161 ms
317,580 KB
testcase_03 AC 155 ms
316,428 KB
testcase_04 AC 159 ms
317,104 KB
testcase_05 AC 154 ms
316,984 KB
testcase_06 AC 157 ms
317,536 KB
testcase_07 AC 152 ms
316,484 KB
testcase_08 AC 155 ms
317,576 KB
testcase_09 AC 155 ms
316,540 KB
testcase_10 AC 158 ms
317,544 KB
testcase_11 AC 153 ms
316,484 KB
testcase_12 AC 156 ms
316,556 KB
testcase_13 AC 153 ms
316,428 KB
testcase_14 AC 157 ms
316,500 KB
testcase_15 AC 154 ms
316,516 KB
testcase_16 AC 152 ms
316,500 KB
testcase_17 AC 143 ms
316,160 KB
testcase_18 AC 142 ms
316,164 KB
testcase_19 AC 147 ms
316,816 KB
testcase_20 AC 152 ms
316,468 KB
testcase_21 AC 140 ms
316,136 KB
testcase_22 AC 141 ms
316,044 KB
testcase_23 AC 151 ms
317,040 KB
testcase_24 AC 160 ms
316,448 KB
testcase_25 AC 146 ms
316,424 KB
testcase_26 AC 157 ms
316,496 KB
testcase_27 AC 147 ms
316,220 KB
testcase_28 AC 152 ms
316,468 KB
testcase_29 AC 140 ms
316,280 KB
testcase_30 AC 155 ms
316,472 KB
testcase_31 AC 158 ms
317,244 KB
testcase_32 AC 156 ms
317,516 KB
testcase_33 AC 156 ms
317,576 KB
testcase_34 AC 142 ms
316,240 KB
testcase_35 AC 143 ms
316,412 KB
testcase_36 AC 157 ms
317,512 KB
testcase_37 AC 147 ms
316,148 KB
testcase_38 AC 142 ms
316,140 KB
testcase_39 AC 145 ms
316,656 KB
testcase_40 AC 153 ms
316,468 KB
testcase_41 AC 153 ms
316,400 KB
testcase_42 AC 155 ms
316,432 KB
testcase_43 AC 145 ms
316,112 KB
testcase_44 AC 150 ms
316,304 KB
testcase_45 AC 154 ms
316,408 KB
testcase_46 AC 153 ms
316,496 KB
testcase_47 AC 143 ms
316,276 KB
testcase_48 AC 157 ms
317,180 KB
testcase_49 AC 155 ms
317,256 KB
testcase_50 AC 147 ms
316,344 KB
testcase_51 AC 147 ms
316,288 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>

/// Reads one decimal integer from standard input with scanf("%d").
///
/// @return the parsed value, or 0 when nothing could be parsed (end of input
///         or a non-numeric token), instead of an indeterminate value.
int ri() {
	int n = 0;
	// scanf reports how many conversions succeeded; anything but 1 means n was not filled in.
	if (scanf("%d", &n) != 1) return 0;
	return n;
}

/// Integer modulo the compile-time constant `mod`.
///
/// The stored residue `x` is kept in [0, mod) by every operation below.
/// The additive operators form sums up to 2 * mod - 1 in an `int`, so `mod`
/// has to be small enough for that to fit (998244353, used below, is).
template<int mod>
struct ModInt{
	int x;
	ModInt () : x(0) {}
	// Reduce with two remainders so that negative inputs (including INT64_MIN,
	// which cannot be negated) land in [0, mod).
	ModInt (int64_t x) : x((x % mod + mod) % mod) {}
	ModInt &operator += (const ModInt &p){
		if ((x += p.x) >= mod) x -= mod;
		return *this;
	}
	ModInt &operator -= (const ModInt &p) {
		if ((x += mod - p.x) >= mod) x -= mod;
		return *this;
	}
	ModInt &operator *= (const ModInt &p) {
		x = (int64_t) x * p.x % mod;
		return *this;
	}
	/// Multiplies by p.inverse(); see inverse() for its precondition.
	ModInt &operator /= (const ModInt &p) {
		*this *= p.inverse();
		return *this;
	}
	/// Raises the value to the power `p` by binary exponentiation.
	/// A negative `p` raises the inverse to the power |p|.
	ModInt &operator ^= (int64_t p) {
		// Work on the magnitude as unsigned: right-shifting a negative value never reaches 0.
		uint64_t e = p < 0 ? uint64_t(0) - uint64_t(p) : uint64_t(p);
		ModInt base = p < 0 ? inverse() : *this;
		ModInt res = 1;
		for (; e; e >>= 1) {
			if (e & 1) res *= base;
			base *= base;
		}
		return *this = res;
	}
	ModInt operator - () const { return ModInt(-x); }
	ModInt operator + (const ModInt &p) const { return ModInt(*this) += p; }
	ModInt operator - (const ModInt &p) const { return ModInt(*this) -= p; }
	ModInt operator * (const ModInt &p) const { return ModInt(*this) *= p; }
	ModInt operator / (const ModInt &p) const { return ModInt(*this) /= p; }
	ModInt operator ^ (int64_t p) const { return ModInt(*this) ^= p; }
	bool operator == (const ModInt &p) const { return x == p.x; }
	bool operator != (const ModInt &p) const { return x != p.x; }
	explicit operator int() const { return x; }
	// Assigning a plain int reduces it exactly like the converting constructor does,
	// so negative values and values >= mod cannot break the [0, mod) invariant.
	ModInt &operator = (const int p) { x = ModInt(p).x; return *this;}
	/// Multiplicative inverse computed with the extended Euclidean algorithm.
	/// NOTE(review): only meaningful when gcd(x, mod) == 1; nothing here checks it.
	ModInt inverse() const {
		int a = x, b = mod, u = 1, v = 0, t;
		while (b > 0) {
			t = a / b;
			a -= t * b;
			std::swap(a, b);
			u -= t * v;
			std::swap(u, v);
		}
		return ModInt(u);
	}
	/// Writes the residue as a plain integer.
	friend std::ostream & operator << (std::ostream &stream, const ModInt<mod> &p) {
		return stream << p.x;
	}
	/// Reads a 64-bit integer and stores its residue.
	friend std::istream & operator >> (std::istream &stream, ModInt<mod> &a) {
		int64_t x;
		stream >> x;
		a = ModInt<mod>(x);
		return stream;
	}
};
using mint = ModInt<998244353>;


// Number of nodes in the pool shared by all maps of one avl_map instantiation (sentinel included).
#define MAX_NODE (200000 * 20)
/// Ordered map kept as an AVL tree, with per-subtree aggregates and lazily
/// propagated range updates.
///
/// Requirements on the template parameters, as used by the code below:
///   key_t   - compared with <, >, <= and ==; range queries also form `key + 1`.
///   value_t - payload stored per key; must be convertible to agg_t.
///   agg_t   - aggregate; a value-initialised agg_t is the empty aggregate and
///             aggregates are combined with `+`.
///   op_t    - pending update; op_t() is the identity, `a * b` composes "a, then b",
///             and apply_val / apply_agg update a value / an aggregate in place.
///
/// Key ranges are half-open [l, r). A bound equal to KEY_MIN / KEY_MAX is
/// treated as unbounded on that side.
///
/// Nodes are taken from one pool per instantiation that is allocated on first
/// use and never freed. Removed nodes are not recycled. A sentinel node NONE
/// (size 0, height 0) stands in for "no child", so child pointers are never null.
/// There is no check that the pool is not exhausted.
template<typename key_t, typename value_t, typename agg_t, typename op_t> struct avl_map {
	static inline constexpr key_t KEY_MIN = std::numeric_limits<key_t>::min();
	static inline constexpr key_t KEY_MAX = std::numeric_limits<key_t>::max();
	struct Node {
		Node *l;
		Node *r;
		int size;
		int height;
		key_t key;
		value_t val;
		agg_t agg; // aggregate of the whole subtree, valid once `op` has been flushed
		op_t op;   // update still to be applied to this node's val/agg and to both subtrees
		/// Recomputes size, height and agg from the children (flushing them first).
		/// The callers in this file flush this node itself beforehand.
		Node *fetch() {
			if (l != avl_map::NONE) l->flush();
			if (r != avl_map::NONE) r->flush();
			size = 1 + l->size + r->size;
			height = 1 + std::max(l->height, r->height);
			agg = l->agg + ((agg_t) val) + r->agg;
			return this;
		}
		/// Applies the pending op to this node and hands it down to the children.
		Node *flush() {
			if (op != op_t()) {
				op.apply_val(val);
				op.apply_agg(agg);
				// Children are never null; the shared sentinel must not collect pending ops.
				if (l != avl_map::NONE) l->op = l->op * op;
				if (r != avl_map::NONE) r->op = r->op * op;
				op = op_t();
			}
			return this;
		}
		/// Left rotation: the right child becomes the root of this subtree.
		Node *rotate_l() {
			Node *new_root = r->flush();
			r = new_root->l;
			new_root->l = this;
			return fetch(), new_root->fetch();
		}
		/// Right rotation: the left child becomes the root of this subtree.
		Node *rotate_r() {
			Node *new_root = l->flush();
			l = new_root->r;
			new_root->r = this;
			return fetch(), new_root->fetch();
		}
		int height_diff() { return l->height - r->height; }
		/// Restores balance when the child heights differ by exactly 2 (single or
		/// double rotation); otherwise just refreshes the cached fields.
		/// Returns the root of the resulting subtree.
		Node *balance() {
			int dif = height_diff();
			if (dif == 2) {
				if (l->flush()->height_diff() < 0) l = l->rotate_l();
				return rotate_r();
			} else if (dif == -2) {
				if (r->flush()->height_diff() > 0) r = r->rotate_r();
				return rotate_l();
			} else return fetch();
		}
	};
	static Node *nodes, *removed_tmp, *NONE;
	Node *root;
	static int head; // index of the next unused pool slot; 0 until the pool exists
	/// Wraps an existing subtree root.
	avl_map (Node *root) : avl_map() { this->root = root; }
	/// Creates an empty map. The first construction allocates the pool and
	/// installs the sentinel in slot 0.
	avl_map () {
		if (!head) {
			NONE = nodes = new Node[MAX_NODE];
			nodes[head++] = {NONE, NONE, 0, 0, key_t(), value_t(), agg_t(), op_t()};
		}
		root = NONE;
	}
	/// Inserts (key, val) below `node`; if the key already exists, calls
	/// func(existing_node, val) instead. Returns the new subtree root.
	template<class T> static Node *insert(Node *node, const key_t &key, const value_t &val, const T &func) {
		if (node == NONE) return &(nodes[head++] = {NONE, NONE, 1, 1, key, val, val, op_t()});
		node->flush();
		if (key < node->key) node->l = insert(node->l, key, val, func);
		else if (key > node->key) node->r = insert(node->r, key, val, func);
		else func(node, val);
		return node->balance();
	}
	/// Detaches the rightmost node of a non-empty subtree, leaving it in
	/// removed_tmp, and returns the root of what remains.
	static Node *remove_rightmost(Node *node) {
		node->flush();
		if (node->r != NONE) {
			node->r = remove_rightmost(node->r);
			return node->balance();
		} else return (removed_tmp = node)->l;
	}
	/// Removes `key` from the subtree if present; returns the new subtree root.
	static Node *remove(Node *node, const key_t &key) {
		if (node == NONE) return node;
		node->flush();
		if (key < node->key) node->l = remove(node->l, key);
		else if (key > node->key) node->r = remove(node->r, key);
		else {
			if (node->l == NONE) return node->r;
			// Replace the removed node by its in-order predecessor.
			node->l = remove_rightmost(node->l);
			removed_tmp->l = node->l;
			removed_tmp->r = node->r;
			return removed_tmp->balance();
		}
		return node->balance();
	}
	/// Joins l, the single node m and r into one tree, descending into the
	/// taller side until the heights differ by at most 2.
	/// NOTE(review): presumably requires keys(l) < m->key < keys(r); nothing here checks it.
	static Node *merge_with_root(Node *l, Node *m, Node *r) {
		int dif = l->height - r->height;
		if (-2 <= dif && dif <= 2) {
			m->l = l;
			m->r = r;
			return m->balance();
		} else if (dif > 0) {
			l->flush();
			l->r = merge_with_root(l->r, m, r);
			return l->balance();
		} else {
			r->flush();
			r->l = merge_with_root(l, m, r->l);
			return r->balance();
		}
	}
	/// Concatenates two trees, using the rightmost node of l as the joining node.
	/// NOTE(review): presumably every key in l must precede every key in r.
	static Node *merge(Node *l, Node *r) {
		if (l == NONE) return r;
		if (r == NONE) return l;
		l = remove_rightmost(l);
		return merge_with_root(l, removed_tmp, r);
	}
	/// Splits the subtree into {keys < key, keys >= key}.
	static std::pair<Node *, Node *> split(Node *node, const key_t &key) {
		if (node == NONE) return {NONE, NONE};
		node->flush();
		Node *l = node->l;
		Node *r = node->r;
		node->l = node->r = NONE;
		if (key < node->key) {
			auto tmp = split(l, key);
			return {tmp.first, merge_with_root(tmp.second, node, r)};
		} else if (key > node->key) {
			auto tmp = split(r, key);
			return {merge_with_root(l, node, tmp.first), tmp.second};
		} else return {l, merge_with_root(NONE, node->fetch(), r)};
	}
	/// Aggregate of all entries with key in [key_l, key_r).
	static agg_t aggregate(Node *node, const key_t &key_l, const key_t &key_r) {
		agg_t res{};
		if (node == NONE || key_r <= key_l) return res;
		node->flush();
		// Fully covered subtree: the cached aggregate is the answer.
		if (key_l == KEY_MIN && key_r == KEY_MAX) return node->agg;
		// When the range extends past this node, the corresponding bound is
		// irrelevant inside the child and is widened so the shortcut above can fire.
		if (key_l < node->key) res = res + aggregate(node->l, key_l, key_r > node->key ? KEY_MAX : key_r);
		if (key_l <= node->key && key_r > node->key) res = res + (agg_t) node->val;
		if (key_r > node->key + 1) res = res + aggregate(node->r, key_l <= node->key ? KEY_MIN : key_l, key_r);
		return res;
	}
	/// Applies `op` to every entry with key in [key_l, key_r).
	static void operate(Node *node, const key_t &key_l, const key_t &key_r, const op_t &op) {
		if (node == NONE || key_r <= key_l) return;
		node->flush();
		if (key_l == KEY_MIN && key_r == KEY_MAX) {
			// Fully covered subtree: apply here and leave the rest pending on the children.
			node->op = op;
			node->flush();
			return;
		}
		if (key_l < node->key) operate(node->l, key_l, key_r > node->key ? KEY_MAX : key_r, op);
		if (key_l <= node->key && key_r > node->key) op.apply_val(node->val);
		if (key_r > node->key + 1) operate(node->r, key_l <= node->key ? KEY_MIN : key_l, key_r, op);
		node->fetch();
	}
	/// Appends the (key, value) pairs of the subtree to `res` in ascending key
	/// order. An empty subtree contributes nothing.
	static void dump(Node *node, std::vector<std::pair<key_t, value_t> > &res) {
		if (node == NONE) return;
		node->flush();
		dump(node->l, res);
		res.push_back({node->key, node->val});
		dump(node->r, res);
	}
	// member functions
	/// Number of keys stored.
	int size() const { return root->size; }
	/// Inserts key -> val, overwriting the value of an existing key.
	void assign(const key_t &key, const value_t &val) {
		root = insert(root, key, val, [] (Node *node, const value_t &val) { node->val = val; });
	}
	/// Inserts key -> val, or combines val into an existing key's value with agg_t's `+`.
	void add(const key_t &key, const value_t &val) {
		root = insert(root, key, val, [] (Node *node, const value_t &val) { node->val = ((agg_t) node->val) + (agg_t) val; });
	}
	void remove(const key_t &key) { root = remove(root, key); }
	/// Concatenates the given maps from left to right; an empty list gives an empty map.
	/// The inputs share nodes with the result and must not be used afterwards.
	static avl_map merge(const std::vector<avl_map> &a) {
		if (a.empty()) return avl_map();
		Node *res = a[0].root;
		for (int i = 1; i < (int) a.size(); i++) res = merge(res, a[i].root);
		return avl_map(res);
	}
	/// Splits into {keys < key, keys >= key}. This map's own root is left
	/// untouched (stale), so use only the two returned maps afterwards.
	std::pair<avl_map, avl_map> split(const key_t &key) {
		auto tmp = split(root, key);
		return {avl_map(tmp.first), avl_map(tmp.second)};
	}
	agg_t aggregate(const key_t &l, const key_t &r) { return aggregate(root, l, r); }
	agg_t aggregate_left(const key_t &r) { return aggregate(root, KEY_MIN, r); }
	agg_t aggregate_right(const key_t &l) { return aggregate(root, l, KEY_MAX); }
	void operate(const key_t &l, const key_t &r, const op_t &op) { operate(root, l, r, op); }
	void operate_left(const key_t &r, const op_t &op) { operate(root, KEY_MIN, r, op); }
	void operate_right(const key_t &l, const op_t &op) { operate(root, l, KEY_MAX, op); }
	/// All (key, value) pairs in ascending key order.
	std::vector<std::pair<key_t, value_t> > dump() const {
		std::vector<std::pair<key_t, value_t> > res;
		dump(root, res);
		return res;
	}
	/// Makes the whole pool (except the sentinel) available again; every
	/// existing map of this instantiation becomes invalid.
	static void reset() { if (head) head = 1; }
	void clear() { root = NONE; }
};
template<typename key_t, typename value_t, typename agg_t, typename op_t> typename avl_map<key_t, value_t, agg_t, op_t>::Node
	*avl_map<key_t, value_t, agg_t, op_t>::nodes;
template<typename key_t, typename value_t, typename agg_t, typename op_t> typename avl_map<key_t, value_t, agg_t, op_t>::Node
	*avl_map<key_t, value_t, agg_t, op_t>::removed_tmp;
template<typename key_t, typename value_t, typename agg_t, typename op_t> typename avl_map<key_t, value_t, agg_t, op_t>::Node
	*avl_map<key_t, value_t, agg_t, op_t>::NONE;
template<typename key_t, typename value_t, typename agg_t, typename op_t> int avl_map<key_t, value_t, agg_t, op_t>::head = 0;



/// Element-wise sum of two equally sized arrays: result[i] = a[i] + b[i].
template<typename T, size_t n> std::array<T, n> array_add(const std::array<T, n> &a, const std::array<T, n> &b) {
	std::array<T, n> sum;
	std::transform(a.begin(), a.end(), b.begin(), sum.begin(),
		[] (const T &x, const T &y) { return x + y; });
	return sum;
}

struct value_t {
	std::array<mint, 4> vec{};
	value_t operator + (const value_t &rhs) {
		std::array<mint, 4> res;
		for (int i = 0; i < 4; i++) res[i] = vec[i] + rhs.vec[i];
		return {res};
	}
};
struct op_t {
	std::array<std::array<mint, 2>, 2> mat = {{{1, 0}, {0, 1}}};
	
	static op_t zero() { return op_t{{{{0, 0}, {0, 0}}}}; };
	template<class T> void apply(T &val) const {
		std::array<mint, 4> tmp{};
		for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) tmp[j + 0] += val.vec[i + 0] * mat[i][j];
		for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) tmp[j + 2] += val.vec[i + 2] * mat[i][j];
		val.vec = tmp;
	}
	void apply_val(value_t &val) const { apply(val); }
	void apply_agg(value_t &sum) const { apply(sum); }
	bool operator != (const op_t &rhs) const { return mat != rhs.mat; }
	op_t operator * (const op_t &rhs) const {
		op_t res;
		for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++)
			res.mat[i][j] = mat[i][0] * rhs.mat[0][j] + mat[i][1] * rhs.mat[1][j];
		return res;
	}
};
using map = avl_map<int, value_t, value_t, op_t>;

// Adjacency list: hen[v] holds the 0-based neighbours of vertex v (filled in by main).
std::vector<std::vector<int> > hen;

/// Combines two maps into `lhs`, iterating over the smaller one.
///
/// If `rhs` holds more keys than `lhs`, the two are swapped first, so the caller's
/// `lhs` always ends up with the result. Afterwards, with S = entries of the smaller
/// map and L = entries of the larger one:
///   - every entry of L is multiplied by the matrix [[s0, s1], [0, s0]], where s is the
///     component-wise sum of the S entries whose key is <= the L entry's key;
///   - for every S entry (key, v), with c = sum of the original L entries whose key is
///     < key, the vector (v0*c0, v0*c1 + v1*c0, key*v0*c0, key*(v0*c1 + v1*c0)) is added at `key`.
void merge(map &lhs, map &rhs) {
	if (rhs.size() > lhs.size()) std::swap(lhs, rhs);

	const auto small = rhs.dump();

	// Contributions keyed by the smaller map's keys. They are computed against the
	// still-unmodified lhs and only inserted at the very end.
	std::vector<std::pair<int, std::array<mint, 4> > > pending;
	for (auto it = small.rbegin(); it != small.rend(); ++it) {
		const int key = it->first;
		const auto &val = it->second.vec;
		const auto below = lhs.aggregate_left(key).vec;
		const mint r0 = val[0] * below[0];
		const mint r1 = val[0] * below[1] + val[1] * below[0];
		const mint r2 = r0 * key;
		const mint r3 = r1 * key;
		pending.push_back({key, {r0, r1, r2, r3}});
	}

	// Scale lhs piecewise: each stretch between consecutive keys of the smaller map
	// is multiplied by the running prefix sum of the smaller map's entries.
	mint prefix[4] = { 0 };
	int from = map::KEY_MIN;
	const auto scale_up_to = [&] (int to) {
		lhs.operate(from, to, op_t{{{{prefix[0], prefix[1]}, {0, prefix[0]}}}});
	};
	for (const auto &entry : small) {
		scale_up_to(entry.first);
		from = entry.first;
		for (int j = 0; j < 4; j++) prefix[j] += entry.second.vec[j];
	}
	scale_up_to(map::KEY_MAX);

	for (const auto &item : pending) lhs.add(item.first, value_t{item.second});
}

// cnt[v]: number of vertices in the subtree of v; starts at 1 and is accumulated by dfs.
std::vector<int> cnt;
// dp[v]: the map built for the subtree of v by dfs.
std::vector<map> dp;
// a[v]: the integer read for vertex v.
std::vector<int> a;
// power2[k] = 2^k modulo 998244353, for 0 <= k <= n.
std::vector<mint> power2;
// Number of vertices.
int n;
// Answer accumulated by dfs and printed by main.
mint res = 0;

void dfs(int i, int prev) {
	dp[i].assign(a[i], value_t({1, 1, a[i], a[i]}));
	for (auto j : hen[i]) if (j != prev) {
		dfs(j, i);
		cnt[i] += cnt[j];
		dp[j].assign(-1, value_t({power2[cnt[j] - 1], 0, 0, -power2[cnt[j] - 1]}));
		merge(dp[i], dp[j]);
	}
	res += dp[i].aggregate(map::KEY_MIN, map::KEY_MAX).vec[3] * power2[n - cnt[i] - (prev != -1)];
}


/// Reads n, the n vertex values and the n - 1 edges (1-based endpoints), prepares
/// the tables used by dfs, runs dfs from vertex 0 and prints the accumulated answer.
int main() {
	n = ri();

	a.resize(n);
	for (int v = 0; v < n; v++) a[v] = ri();

	hen.resize(n);
	for (int e = 0; e + 1 < n; e++) {
		const int x = ri() - 1;
		const int y = ri() - 1;
		hen[x].push_back(y);
		hen[y].push_back(x);
	}

	dp.resize(n);
	cnt.resize(n, 1);

	// power2[k] = 2^k, built by doubling.
	power2.resize(n + 1, 1);
	for (int k = 1; k <= n; k++) power2[k] = power2[k - 1] * 2;

	res = 0;
	dfs(0, -1);
	std::cout << res << std::endl;
	return 0;
}
0