結果
問題 | No.2116 Making Forest Hard |
ユーザー |
![]() |
提出日時 | 2022-10-11 04:32:53 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 |
AC
|
実行時間 | 5,679 ms / 8,000 ms |
コード長 | 12,744 bytes |
コンパイル時間 | 6,516 ms |
コンパイル使用メモリ | 304,896 KB |
最終ジャッジ日時 | 2025-02-08 01:21:12 |
ジャッジサーバーID (参考情報) |
judge2 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 53 |
ソースコード
#if defined(LOCAL)#include<stdc++.h>#else#include<bits/stdc++.h>#endif#include<random>#pragma GCC optimize("Ofast")//#pragma GCC target("avx2")#pragma GCC optimize("unroll-loops")using namespace std;//#include<boost/multiprecision/cpp_int.hpp>//#include<boost/multiprecision/cpp_dec_float.hpp>//namespace mp=boost::multiprecision;//#define mulint mp::cpp_int//#define mulfloat mp::cpp_dec_float_100struct __INIT{__INIT(){cin.tie(0);ios::sync_with_stdio(false);cout<<fixed<<setprecision(15);}} __init;//#define INF (1<<30)#define LINF (lint)(1LL<<56)#define MINF (lint)(2e18)#define endl "\n"#define rep(i,n) for(lint (i)=0;(i)<(n);(i)++)#define reprev(i,n) for(lint (i)=(n-1);(i)>=0;(i)--)#define flc(x) __builtin_popcountll(x)#define pint pair<int,int>#define pdouble pair<double,double>#define plint pair<lint,lint>#define fi first#define se second#define all(x) x.begin(),x.end()//#define vec vector<lint>#define nep(x) next_permutation(all(x))typedef long long lint;int dx[8]={1,1,0,-1,-1,-1,0,1};int dy[8]={0,1,1,1,0,-1,-1,-1};const int MAX_N=3e5+5;template<class T>bool chmax(T &a,const T &b){if(a<b){a=b;return 1;}return 0;}template<class T>bool chmin(T &a,const T &b){if(b<a){a=b;return 1;}return 0;}//vector<int> bucket[MAX_N/1000];//constexpr int MOD=1000000007;constexpr int MOD=998244353;#include<atcoder/all>using namespace atcoder;typedef __int128_t llint;using mint=modint998244353;//https://github.com/NachiaVivias/cp-library//https://judge.yosupo.jp/submission/77455struct SplayTreeByIdx{struct S{mint dp1,dp2,dp1key,dp2key;};static S op(S l, S r){ return {l.dp1+r.dp1,l.dp2+r.dp2,l.dp1key+r.dp1key,l.dp2key+r.dp2key}; }static S e(){ return {0,0,0,0}; }struct F{mint mt00,mt01,mt10,mt11; };static S mapping(F f, S x){ return { f.mt00*x.dp1+f.mt01*x.dp2,f.mt10*x.dp1+f.mt11*x.dp2,f.mt00*x.dp1key+f.mt01*x.dp2key,f.mt10*x.dp1key+f.mt11*x.dp2key};}static F composition(F f, F x){ return { f.mt00*x.mt00+f.mt01*x.mt10,f.mt00*x.mt01+f.mt01*x.mt11,f.mt10*x.mt00+f.mt11*x.mt10,f.mt10*x.mt01+f.mt11*x.mt11}; }static F id(){ return {1,0,0,1}; }struct Node{Node *l = 0, *r = 0, *p = 0;S a = e();S prod = e();F f = id();bool propagated = true;int z = 0;int sumz = 0;int rev = 0;};Node *NIL = nullptr;Node *R;void prepareDown(Node* c){if(!c->propagated){if(c->l != NIL){c->l->a = mapping(c->f, c->l->a);c->l->prod = mapping(c->f, c->l->prod);c->l->f = composition(c->f, c->l->f);c->l->propagated = false;}if(c->r != NIL){c->r->a = mapping(c->f, c->r->a);c->r->prod = mapping(c->f, c->r->prod);c->r->f = composition(c->f, c->r->f);c->r->propagated = false;}c->f = id();c->propagated = true;}if(c->rev){swap(c->l, c->r);if(c->l != NIL) c->l->rev ^= 1;if(c->r != NIL) c->r->rev ^= 1;c->rev = 0;}}void prepareUp(Node* c){c->sumz = c->l->sumz + c->r->sumz + 1;c->prod = op(op(c->l->prod,c->a),c->r->prod);}SplayTreeByIdx(){if(!NIL){NIL = new Node();NIL->l = NIL->r = NIL->p = NIL;R = NIL;}}Node*& parentchild(Node* p){if(p->p == NIL) return R;if(p->p->l == p) return p->p->l;else return p->p->r;}void rotL(Node* c){Node* p = c->p;parentchild(p) = c;c->p = p->p;p->p = c;if(c->l != NIL) c->l->p = p;p->r = c->l;c->l = p;}void rotR(Node* c){Node* p = c->p;parentchild(p) = c;c->p = p->p;p->p = c;if(c->r != NIL) c->r->p = p;p->l = c->r;c->r = p;}void splay(Node* c){while(c->p != NIL){Node* p = c->p;Node* pp = p->p;if(p->l == c){if(pp == NIL){ rotR(c); }else if(pp->l == p){ rotR(p); rotR(c); }else if(pp->r == p){ rotR(c); rotL(c); }}else{if(pp == NIL){ rotL(c); }else if(pp->r == p){ rotL(p); rotL(c); }else if(pp->l == p){ rotL(c); rotR(c); }}if(pp != NIL) prepareUp(pp);if(p != NIL) prepareUp(p);}prepareUp(c);}Node* kth_element(int k){if(k >= R->sumz) return NIL;Node* c = R;while(true){prepareDown(c);auto cl = c->l;if(cl->sumz == k) break;if(cl->sumz > k){ c = cl; continue; }k -= cl->sumz + 1;c = c->r;}prepareDown(c);splay(c);return c;}Node* insert_at(int k, S x){Node* nx = new Node(*NIL);nx->z = nx->sumz = 1;nx->a = nx->prod = x;if(k == R->sumz){nx->l = R;if(R != NIL) R->p = nx;R = nx;prepareUp(nx);return nx;}auto p = kth_element(k);nx->l = p->l;nx->r = p;R = nx;if(p->l != NIL){prepareDown(p->l);p->l->p = nx;}p->p = nx;p->l = NIL;prepareUp(p);prepareUp(nx);return nx;}void erase_at(int k){if(k >= R->sumz) return;auto toerase = NIL;if(k == 0){kth_element(0);prepareDown(R);toerase = R;R = R->r;if(R != NIL) R->p = NIL;}else{kth_element(k-1);prepareDown(R);auto c = R->r;prepareDown(c);while(c->l != NIL){c = c->l;prepareDown(c);}auto p = c->p;toerase = c;parentchild(c) = c->r;if(c->r != NIL) c->r->p = p;splay(p);//delete p;}}Node* between(int l, int r){if(l >= r) return NIL;if(l == 0 && r == R->sumz) return R;if(l == 0) return kth_element(r)->l;if(r == R->sumz) return kth_element(l-1)->r;auto lp = kth_element(l-1);auto rp = kth_element(r);while(rp->l != lp){auto p = lp->p;prepareDown(p);prepareDown(lp);if(p->l == lp) rotR(lp);else rotL(lp);prepareUp(p);prepareUp(lp);}return lp->r;}void reverse(int l, int r){if(l >= r) return;auto c = between(l,r);c->rev ^= 1;prepareDown(c);splay(c);}void apply(int l, int r, F f){if(l >= r) return;auto c = between(l,r);c->a = mapping(f,c->a);c->prod = mapping(f,c->prod);c->f = composition(f,c->f);c->propagated = false;prepareDown(c);splay(c);}S prod(int l, int r){if(l >= r) return e();return between(l,r)->prod;}};template<typename U = unsigned, int B = 32>class lazy_binary_trie {struct node {int cnt;U lazy;node *ch[2];node() : cnt(0), lazy(0), ch{ nullptr, nullptr } {}};void push(node* t, int b) {if ((t->lazy >> (U)b) & (U)1) swap(t->ch[0], t->ch[1]);if (t->ch[0]) t->ch[0]->lazy ^= t->lazy;if (t->ch[1]) t->ch[1]->lazy ^= t->lazy;t->lazy = 0;}node* add(node* t, U val, int b = B - 1) {if (!t) t = new node;t->cnt += 1;if (b < 0) return t;push(t, b);bool f = (val >> (U)b) & (U)1;t->ch[f] = add(t->ch[f], val, b - 1);return t;}node* sub(node* t, U val, int b = B - 1) {assert(t);t->cnt -= 1;if (t->cnt == 0){delete t;return nullptr;}if (b < 0) return t;push(t, b);bool f = (val >> (U)b) & (U)1;t->ch[f] = sub(t->ch[f], val, b - 1);return t;}U get_min(node* t, U val, int b = B - 1) {assert(t);if (b < 0) return 0;push(t, b);bool f = (val >> (U)b) & (U)1; f ^= !t->ch[f];return get_min(t->ch[f], val, b - 1) | ((U)f << (U)b);}U get(node* t, int k, int b = B - 1) {if (b < 0) return 0;push(t, b);int m = t->ch[0] ? t->ch[0]->cnt : 0;return k < m ? get(t->ch[0], k, b - 1) : get(t->ch[1], k - m, b - 1) | ((U)1 << (U)b);}int count_lower(node* t, U val, int b = B - 1) {if (!t || b < 0) return 0;push(t, b);bool f = (val >> (U)b) & (U)1;return (f && t->ch[0] ? t->ch[0]->cnt : 0) + count_lower(t->ch[f], val, b - 1);}node *root;public:lazy_binary_trie() : root(nullptr) {}int size() const {return root ? root->cnt : 0;}bool empty() const {return !root;}void insert(U val) {root = add(root, val);}void erase(U val) {root = sub(root, val);}void xor_all(U val) {if (root) root->lazy ^= val;}U max_element(U bias = 0) {return get_min(root, ~bias);}U min_element(U bias = 0) {return get_min(root, bias);}int lower_bound(U val) { // return idreturn count_lower(root, val);}int upper_bound(U val) { // return idreturn count_lower(root, val + 1);}U operator[](int k) {assert(0 <= k && k < size());return get(root, k);}U at(int k) {assert(0 <= k && k < size());return get(root, k);}int count(U val) {if (!root) return 0;node *t = root;for (int i = B - 1; i >= 0; i--) {push(t, i);t = t->ch[(val >> (U)i) & (U)1];if (!t) return 0;}return t->cnt;}void del(node *t){if(t->ch[0]!=nullptr) del(t->ch[0]);if(t->ch[1]!=nullptr) del(t->ch[1]);delete t;}void clear(){node *t=root;if(t->ch[0]!=nullptr) del(t->ch[0]);if(t->ch[1]!=nullptr) del(t->ch[1]);delete t;}};SplayTreeByIdx dp[100005];vector<SplayTreeByIdx*> dp_pt;mint ans=0;int N;vector<int> edge[100005];int sub_size[100005];mint pow2[100005];lint A[100005];int pt[100005];vector<lazy_binary_trie<int,30>> bt(100000);vector<lazy_binary_trie<int,30>*> bt_pt(100000);bool merge(int left,int right){int rev=false;if(bt_pt[left]->size()<bt_pt[right]->size()) swap(dp_pt[left],dp_pt[right]),swap(bt_pt[left],bt_pt[right]),rev=true; //マージテクvector<SplayTreeByIdx::S> adds;int N=bt_pt[right]->size();rep(i,N){ //key未満SplayTreeByIdx::S now=dp_pt[right]->prod(i,i+1);int key=bt_pt[right]->at(i);int hi=bt_pt[left]->lower_bound(key);SplayTreeByIdx::S lo_sum=dp_pt[left]->prod(0,hi);mint dp1=now.dp1*lo_sum.dp1;mint dp2=now.dp1*lo_sum.dp2+now.dp2*lo_sum.dp1;mint dp1_key=dp1*key;mint dp2_key=dp2*key;adds.push_back({dp1,dp2,dp1_key,dp2_key});}int before=0;mint sum[2];rep(i,N){SplayTreeByIdx::S now=dp_pt[right]->prod(i,i+1);int key=bt_pt[right]->at(i);int hi=bt_pt[left]->lower_bound(key);if(before!=hi){dp_pt[left]->apply(before,hi,{sum[0],0,sum[1],sum[0]});}sum[0]+=now.dp1,sum[1]+=now.dp2;before=hi;}if(before!=bt_pt[left]->size()){dp_pt[left]->apply(before,bt_pt[left]->size(),{sum[0],0,sum[1],sum[0]});}int idx=0;for(auto e:adds){int key=bt_pt[right]->at(idx);int hi=bt_pt[left]->lower_bound(key);if(bt_pt[left]->count(key)){SplayTreeByIdx::S res=dp_pt[left]->prod(hi,hi+1);res.dp1+=e.dp1,res.dp2+=e.dp2,res.dp1key+=e.dp1key,res.dp2key+=e.dp2key;dp_pt[left]->erase_at(hi);dp_pt[left]->insert_at(hi,res);}else{dp_pt[left]->insert_at(hi,e);bt_pt[left]->insert(key);}idx++;}return rev;}void dfs(int now,int par){for(auto e:edge[now]){if(e==par) continue;dfs(e,now);sub_size[now]+=sub_size[e];if(bt_pt[e]->count(0)){SplayTreeByIdx::S zero=dp_pt[e]->prod(0,1);dp_pt[e]->erase_at(0);zero.dp1+=pow2[sub_size[e]-1];dp_pt[e]->insert_at(0,zero);}else{dp_pt[e]->insert_at(0,{pow2[sub_size[e]-1],0,0,0});bt_pt[e]->insert(0);}merge(now,e);bt_pt[e]->clear();}ans+=dp_pt[now]->prod(0,bt_pt[now]->size()).dp2key*pow2[N-1-sub_size[now]+(now==0)];}int main(void){cin >> N;rep(i,N) cin >> A[i];rep(i,N) sub_size[i]=1;pow2[0]=1;rep(i,100004) pow2[i+1]=pow2[i]*2;rep(i,100004) pt[i]=i;rep(i,N) bt_pt[i]=&bt[i];rep(i,N){dp[i].insert_at(0,{1,1,A[i],A[i]});bt_pt[i]->insert(A[i]);dp_pt.push_back(&dp[i]);}rep(i,N-1){int u,v;cin >> u >> v;u--,v--;edge[u].push_back(v);edge[v].push_back(u);}dfs(0,-1);cout << ans.val() << endl;}