結果
問題 | No.2116 Making Forest Hard |
ユーザー | bayashiko |
提出日時 | 2022-10-26 22:03:24 |
言語 | C++17 (gcc 12.3.0 + boost 1.83.0) |
結果 |
TLE
|
実行時間 | - |
コード長 | 13,180 bytes |
コンパイル時間 | 6,908 ms |
コンパイル使用メモリ | 312,388 KB |
実行使用メモリ | 103,352 KB |
最終ジャッジ日時 | 2024-07-04 10:23:00 |
合計ジャッジ時間 | 24,763 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge2 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 16 ms
19,796 KB |
testcase_01 | AC | 18 ms
18,944 KB |
testcase_02 | AC | 4,239 ms
93,012 KB |
testcase_03 | AC | 2,165 ms
90,140 KB |
testcase_04 | TLE | - |
testcase_05 | -- | - |
testcase_06 | -- | - |
testcase_07 | -- | - |
testcase_08 | -- | - |
testcase_09 | -- | - |
testcase_10 | -- | - |
testcase_11 | -- | - |
testcase_12 | -- | - |
testcase_13 | -- | - |
testcase_14 | -- | - |
testcase_15 | -- | - |
testcase_16 | -- | - |
testcase_17 | -- | - |
testcase_18 | -- | - |
testcase_19 | -- | - |
testcase_20 | -- | - |
testcase_21 | -- | - |
testcase_22 | -- | - |
testcase_23 | -- | - |
testcase_24 | -- | - |
testcase_25 | -- | - |
testcase_26 | -- | - |
testcase_27 | -- | - |
testcase_28 | -- | - |
testcase_29 | -- | - |
testcase_30 | -- | - |
testcase_31 | -- | - |
testcase_32 | -- | - |
testcase_33 | -- | - |
testcase_34 | -- | - |
testcase_35 | -- | - |
testcase_36 | -- | - |
testcase_37 | -- | - |
testcase_38 | -- | - |
testcase_39 | -- | - |
testcase_40 | -- | - |
testcase_41 | -- | - |
testcase_42 | -- | - |
testcase_43 | -- | - |
testcase_44 | -- | - |
testcase_45 | -- | - |
testcase_46 | -- | - |
testcase_47 | -- | - |
testcase_48 | -- | - |
testcase_49 | -- | - |
testcase_50 | -- | - |
testcase_51 | -- | - |
testcase_52 | -- | - |
testcase_53 | -- | - |
testcase_54 | -- | - |
ソースコード
#if defined(LOCAL) #include<stdc++.h> #else #include<bits/stdc++.h> #endif #include<random> #pragma GCC optimize("Ofast") //#pragma GCC target("avx2") #pragma GCC optimize("unroll-loops") using namespace std; //#include<boost/multiprecision/cpp_int.hpp> //#include<boost/multiprecision/cpp_dec_float.hpp> //namespace mp=boost::multiprecision; //#define mulint mp::cpp_int //#define mulfloat mp::cpp_dec_float_100 struct __INIT{__INIT(){cin.tie(0);ios::sync_with_stdio(false);cout<<fixed<<setprecision(15);}} __init; //#define INF (1<<30) #define LINF (lint)(1LL<<56) #define MINF (lint)(2e18) #define endl "\n" #define rep(i,n) for(lint (i)=0;(i)<(n);(i)++) #define reprev(i,n) for(lint (i)=(n-1);(i)>=0;(i)--) #define flc(x) __builtin_popcountll(x) #define pint pair<int,int> #define pdouble pair<double,double> #define plint pair<lint,lint> #define fi first #define se second #define all(x) x.begin(),x.end() //#define vec vector<lint> #define nep(x) next_permutation(all(x)) typedef long long lint; int dx[8]={1,1,0,-1,-1,-1,0,1}; int dy[8]={0,1,1,1,0,-1,-1,-1}; const int MAX_N=3e5+5; template<class T>bool chmax(T &a,const T &b){if(a<b){a=b;return 1;}return 0;} template<class T>bool chmin(T &a,const T &b){if(b<a){a=b;return 1;}return 0;} //vector<int> bucket[MAX_N/1000]; //constexpr int MOD=1000000007; constexpr int MOD=998244353; #include<atcoder/all> using namespace atcoder; typedef __int128_t llint; using mint=modint998244353; vector<int> rs; //https://github.com/NachiaVivias/cp-library //https://judge.yosupo.jp/submission/77455 struct SplayTreeByIdx{ struct S{mint dp1,dp2,dp1key,dp2key;}; static S op(S l, S r){ return {l.dp1+r.dp1,l.dp2+r.dp2,l.dp1key+r.dp1key,l.dp2key+r.dp2key}; } static S e(){ return {0,0,0,0}; } struct F{mint mt00,mt01,mt10,mt11; }; static S mapping(F f, S x){ return { f.mt00*x.dp1+f.mt01*x.dp2,f.mt10*x.dp1+f.mt11*x.dp2,f.mt00*x.dp1key+f.mt01*x.dp2key,f.mt10*x.dp1key+f.mt11*x.dp2key};} static F composition(F f, F x){ return { f.mt00*x.mt00+f.mt01*x.mt10,f.mt00*x.mt01+f.mt01*x.mt11,f.mt10*x.mt00+f.mt11*x.mt10,f.mt10*x.mt01+f.mt11*x.mt11}; } static F id(){ return {1,0,0,1}; } struct Node{ Node *l = 0, *r = 0, *p = 0; S a = e(); S prod = e(); F f = id(); bool propagated = true; int z = 0; int sumz = 0; int rev = 0; }; Node *NIL = nullptr; Node *R; void prepareDown(Node* c){ if(!c->propagated){ if(c->l != NIL){ c->l->a = mapping(c->f, c->l->a); c->l->prod = mapping(c->f, c->l->prod); c->l->f = composition(c->f, c->l->f); c->l->propagated = false; } if(c->r != NIL){ c->r->a = mapping(c->f, c->r->a); c->r->prod = mapping(c->f, c->r->prod); c->r->f = composition(c->f, c->r->f); c->r->propagated = false; } c->f = id(); c->propagated = true; } if(c->rev){ swap(c->l, c->r); if(c->l != NIL) c->l->rev ^= 1; if(c->r != NIL) c->r->rev ^= 1; c->rev = 0; } } void prepareUp(Node* c){ c->sumz = c->l->sumz + c->r->sumz + 1; c->prod = op(op(c->l->prod,c->a),c->r->prod); } SplayTreeByIdx(){ if(!NIL){ NIL = new Node(); NIL->l = NIL->r = NIL->p = NIL; R = NIL; } } Node*& parentchild(Node* p){ if(p->p == NIL) return R; if(p->p->l == p) return p->p->l; else return p->p->r; } void rotL(Node* c){ Node* p = c->p; parentchild(p) = c; c->p = p->p; p->p = c; if(c->l != NIL) c->l->p = p; p->r = c->l; c->l = p; } void rotR(Node* c){ Node* p = c->p; parentchild(p) = c; c->p = p->p; p->p = c; if(c->r != NIL) c->r->p = p; p->l = c->r; c->r = p; } void splay(Node* c){ while(c->p != NIL){ Node* p = c->p; Node* pp = p->p; if(p->l == c){ if(pp == NIL){ rotR(c); } else if(pp->l == p){ rotR(p); rotR(c); } else if(pp->r == p){ rotR(c); rotL(c); } } else{ if(pp == NIL){ rotL(c); } else if(pp->r == p){ rotL(p); rotL(c); } else if(pp->l == p){ rotL(c); rotR(c); } } if(pp != NIL) prepareUp(pp); if(p != NIL) prepareUp(p); } prepareUp(c); } Node* kth_element(int k){ if(k >= R->sumz) return NIL; Node* c = R; while(true){ prepareDown(c); auto cl = c->l; if(cl->sumz == k) break; if(cl->sumz > k){ c = cl; continue; } k -= cl->sumz + 1; c = c->r; } prepareDown(c); splay(c); return c; } Node* insert_at(int k, S x){ Node* nx = new Node(*NIL); nx->z = nx->sumz = 1; nx->a = nx->prod = x; if(k == R->sumz){ nx->l = R; if(R != NIL) R->p = nx; R = nx; prepareUp(nx); return nx; } auto p = kth_element(k); nx->l = p->l; nx->r = p; R = nx; if(p->l != NIL){ prepareDown(p->l); p->l->p = nx; } p->p = nx; p->l = NIL; prepareUp(p); prepareUp(nx); return nx; } void erase_at(int k){ if(k >= R->sumz) return; auto toerase = NIL; if(k == 0){ kth_element(0); prepareDown(R); toerase = R; R = R->r; if(R != NIL) R->p = NIL; } else{ kth_element(k-1); prepareDown(R); auto c = R->r; prepareDown(c); while(c->l != NIL){ c = c->l; prepareDown(c); } auto p = c->p; toerase = c; parentchild(c) = c->r; if(c->r != NIL) c->r->p = p; splay(p); //delete p; } } Node* between(int l, int r){ if(l >= r) return NIL; if(l == 0 && r == R->sumz) return R; if(l == 0) return kth_element(r)->l; if(r == R->sumz) return kth_element(l-1)->r; auto lp = kth_element(l-1); auto rp = kth_element(r); while(rp->l != lp){ auto p = lp->p; prepareDown(p); prepareDown(lp); if(p->l == lp) rotR(lp); else rotL(lp); prepareUp(p); prepareUp(lp); } return lp->r; } void reverse(int l, int r){ if(l >= r) return; auto c = between(l,r); c->rev ^= 1; prepareDown(c); splay(c); } void apply(int l, int r, F f){ if(l >= r) return; auto c = between(l,r); c->a = mapping(f,c->a); c->prod = mapping(f,c->prod); c->f = composition(f,c->f); c->propagated = false; prepareDown(c); splay(c); } S prod(int l, int r){ if(l >= r) return e(); return between(l,r)->prod; } void del(Node* tmp){ if(tmp->l!=NIL) del(tmp->l); if(tmp->r!=NIL) del(tmp->r); delete tmp; } void clear(){ while(R->p!=NIL) R = R->p; del(R); } }; template<typename U = unsigned, int B = 32> class lazy_binary_trie { struct node { int cnt; U lazy; node *ch[2]; node() : cnt(0), lazy(0), ch{ nullptr, nullptr } {} }; void push(node* t, int b) { if ((t->lazy >> (U)b) & (U)1) swap(t->ch[0], t->ch[1]); if (t->ch[0]) t->ch[0]->lazy ^= t->lazy; if (t->ch[1]) t->ch[1]->lazy ^= t->lazy; t->lazy = 0; } node* add(node* t, U val, int b = B - 1) { if (!t) t = new node; t->cnt += 1; if (b < 0) return t; push(t, b); bool f = (val >> (U)b) & (U)1; t->ch[f] = add(t->ch[f], val, b - 1); return t; } node* sub(node* t, U val, int b = B - 1) { assert(t); t->cnt -= 1; if (t->cnt == 0){ delete t; return nullptr; } if (b < 0) return t; push(t, b); bool f = (val >> (U)b) & (U)1; t->ch[f] = sub(t->ch[f], val, b - 1); return t; } U get_min(node* t, U val, int b = B - 1) { assert(t); if (b < 0) return 0; push(t, b); bool f = (val >> (U)b) & (U)1; f ^= !t->ch[f]; return get_min(t->ch[f], val, b - 1) | ((U)f << (U)b); } U get(node* t, int k, int b = B - 1) { if (b < 0) return 0; push(t, b); int m = t->ch[0] ? t->ch[0]->cnt : 0; return k < m ? get(t->ch[0], k, b - 1) : get(t->ch[1], k - m, b - 1) | ((U)1 << (U)b); } int count_lower(node* t, U val, int b = B - 1) { if (!t || b < 0) return 0; push(t, b); bool f = (val >> (U)b) & (U)1; return (f && t->ch[0] ? t->ch[0]->cnt : 0) + count_lower(t->ch[f], val, b - 1); } node *root; public: lazy_binary_trie() : root(nullptr) {} int size() const { return root ? root->cnt : 0; } bool empty() const { return !root; } void insert(U val) { root = add(root, val); } void erase(U val) { root = sub(root, val); } void xor_all(U val) { if (root) root->lazy ^= val; } U max_element(U bias = 0) { return get_min(root, ~bias); } U min_element(U bias = 0) { return get_min(root, bias); } int lower_bound(U val) { // return id return count_lower(root, val); } int upper_bound(U val) { // return id return count_lower(root, val + 1); } U operator[](int k) { assert(0 <= k && k < size()); return get(root, k); } U at(int k) { assert(0 <= k && k < size()); return get(root, k); } int count(U val) { if (!root) return 0; node *t = root; for (int i = B - 1; i >= 0; i--) { push(t, i); t = t->ch[(val >> (U)i) & (U)1]; if (!t) return 0; } return t->cnt; } void del(node *t){ if(t->ch[0]!=nullptr) del(t->ch[0]); if(t->ch[1]!=nullptr) del(t->ch[1]); delete t; } void clear(){ node *t=root; if(t->ch[0]!=nullptr) del(t->ch[0]); if(t->ch[1]!=nullptr) del(t->ch[1]); delete t; } }; SplayTreeByIdx dp[100005]; vector<SplayTreeByIdx*> dp_pt; mint ans=0; int N; vector<int> edge[100005]; int sub_size[100005]; mint pow2[100005]; lint A[100005]; int pt[100005]; vector<lazy_binary_trie<int,17>> bt(100000); vector<lazy_binary_trie<int,17>*> bt_pt(100000); bool merge(int left,int right){ int rev=false; //if(bt_pt[left]->size()<bt_pt[right]->size()) swap(dp_pt[left],dp_pt[right]),swap(bt_pt[left],bt_pt[right]),rev=true; //マージテク vector<SplayTreeByIdx::S> adds; int N=bt_pt[right]->size(); rep(i,N){ //key未満 SplayTreeByIdx::S now=dp_pt[right]->prod(i,i+1); int key=bt_pt[right]->at(i); int hi=bt_pt[left]->lower_bound(key); SplayTreeByIdx::S lo_sum=dp_pt[left]->prod(0,hi); mint dp1=now.dp1*lo_sum.dp1; mint dp2=now.dp1*lo_sum.dp2+now.dp2*lo_sum.dp1; mint dp1_key=dp1*rs[key]; mint dp2_key=dp2*rs[key]; adds.push_back({dp1,dp2,dp1_key,dp2_key}); } int before=0; mint sum[2]; rep(i,N){ SplayTreeByIdx::S now=dp_pt[right]->prod(i,i+1); int key=bt_pt[right]->at(i); int hi=bt_pt[left]->lower_bound(key); if(before!=hi){ dp_pt[left]->apply(before,hi,{sum[0],0,sum[1],sum[0]}); } sum[0]+=now.dp1,sum[1]+=now.dp2; before=hi; } if(before!=bt_pt[left]->size()){ dp_pt[left]->apply(before,bt_pt[left]->size(),{sum[0],0,sum[1],sum[0]}); } int idx=0; for(auto e:adds){ int key=bt_pt[right]->at(idx); int hi=bt_pt[left]->lower_bound(key); if(bt_pt[left]->count(key)){ SplayTreeByIdx::S res=dp_pt[left]->prod(hi,hi+1); res.dp1+=e.dp1,res.dp2+=e.dp2,res.dp1key+=e.dp1key,res.dp2key+=e.dp2key; dp_pt[left]->erase_at(hi); dp_pt[left]->insert_at(hi,res); } else{ dp_pt[left]->insert_at(hi,e); bt_pt[left]->insert(key); } idx++; } return rev; } void dfs(int now,int par){ for(auto e:edge[now]){ if(e==par) continue; dfs(e,now); sub_size[now]+=sub_size[e]; if(bt_pt[e]->count(0)){ SplayTreeByIdx::S zero=dp_pt[e]->prod(0,1); dp_pt[e]->erase_at(0); zero.dp1+=pow2[sub_size[e]-1]; dp_pt[e]->insert_at(0,zero); } else{ dp_pt[e]->insert_at(0,{pow2[sub_size[e]-1],0,0,0}); bt_pt[e]->insert(0); } merge(now,e); bt_pt[e]->clear(); dp_pt[e]->clear(); } ans+=dp_pt[now]->prod(0,bt_pt[now]->size()).dp2key*pow2[N-1-sub_size[now]+(now==0)]; } int main(void){ cin >> N; rep(i,N) cin >> A[i]; vector<int> as; rep(i,N) as.push_back(A[i]); sort(all(as)); rs.push_back(0); int pre=-1; rep(i,N) if(pre!=as[i]) rs.push_back(as[i]),pre=as[i]; rep(i,N) sub_size[i]=1; pow2[0]=1; rep(i,100004) pow2[i+1]=pow2[i]*2; rep(i,100004) pt[i]=i; rep(i,N) bt_pt[i]=&bt[i]; rep(i,N){ dp[i].insert_at(0,{1,1,A[i],A[i]}); bt_pt[i]->insert(lower_bound(all(rs),A[i])-rs.begin()); dp_pt.push_back(&dp[i]); } rep(i,N-1){ int u,v; cin >> u >> v; u--,v--; edge[u].push_back(v); edge[v].push_back(u); } dfs(0,-1); cout << ans.val() << endl; }