結果

問題 No.2116 Making Forest Hard
ユーザー bayashikobayashiko
提出日時 2022-10-26 22:04:28
言語 C++17 (gcc 12.3.0 + boost 1.83.0)
結果
TLE  
実行時間 -
コード長 13,115 bytes
コンパイル時間 6,784 ms
コンパイル使用メモリ 310,256 KB
実行使用メモリ 100,036 KB
最終ジャッジ日時 2023-09-17 15:16:52
合計ジャッジ時間 21,899 ms
ジャッジサーバーID (参考情報) judge14 / judge13
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 17 ms 18,976 KB
testcase_01 AC 17 ms 18,896 KB
testcase_02 AC 2,994 ms 92,952 KB
testcase_03 AC 1,586 ms 89,928 KB
testcase_04 TLE -
testcase_05 -- -
testcase_06 -- -
testcase_07 -- -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
testcase_29 -- -
testcase_30 -- -
testcase_31 -- -
testcase_32 -- -
testcase_33 -- -
testcase_34 -- -
testcase_35 -- -
testcase_36 -- -
testcase_37 -- -
testcase_38 -- -
testcase_39 -- -
testcase_40 -- -
testcase_41 -- -
testcase_42 -- -
testcase_43 -- -
testcase_44 -- -
testcase_45 -- -
testcase_46 -- -
testcase_47 -- -
testcase_48 -- -
testcase_49 -- -
testcase_50 -- -
testcase_51 -- -
testcase_52 -- -
testcase_53 -- -
testcase_54 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

#if defined(LOCAL)
#include<stdc++.h>
#else
#include<bits/stdc++.h>
#endif
#include<random>
#pragma GCC optimize("Ofast")
//#pragma GCC target("avx2")
#pragma GCC optimize("unroll-loops")
using namespace std;
//#include<boost/multiprecision/cpp_int.hpp>
//#include<boost/multiprecision/cpp_dec_float.hpp>
//namespace mp=boost::multiprecision;
//#define mulint mp::cpp_int
//#define mulfloat mp::cpp_dec_float_100
// Global object whose constructor runs before main(): unties cin from cout,
// turns off synchronisation with C stdio, and switches floating-point output
// to fixed notation with 15 decimals.
struct __INIT{__INIT(){cin.tie(0);ios::sync_with_stdio(false);cout<<fixed<<setprecision(15);}} __init;
//#define INF (1<<30)
// Large 64-bit sentinels: 2^56 and 2e18.
#define LINF (lint)(1LL<<56)
#define MINF (lint)(2e18)
// endl is redefined as a plain newline so that output is not flushed per line.
#define endl "\n"
// rep: i = 0 .. n-1 ascending; reprev: i = n-1 .. 0 descending (loop variable is lint).
// NOTE(review): reprev does not parenthesise n in (n-1).
#define rep(i,n) for(lint (i)=0;(i)<(n);(i)++)
#define reprev(i,n) for(lint (i)=(n-1);(i)>=0;(i)--)
// Population count of a 64-bit value.
#define flc(x) __builtin_popcountll(x)
// Short names for common pair types and members.
#define pint pair<int,int>
#define pdouble pair<double,double>
#define plint pair<lint,lint>
#define fi first
#define se second
// Whole-container iterator range, e.g. sort(all(v)).
#define all(x) x.begin(),x.end()
//#define vec vector<lint>
#define nep(x) next_permutation(all(x))
typedef long long lint;
// 8-neighbourhood offsets (not referenced in the code shown here).
int dx[8]={1,1,0,-1,-1,-1,0,1};
int dy[8]={0,1,1,1,0,-1,-1,-1};
// NOTE(review): MAX_N is not referenced in the code shown here; the global tables use 100005.
const int MAX_N=3e5+5;
// chmax(a, b): if a < b, overwrite a with b and report true; otherwise leave a
// untouched and report false. chmin(a, b) is the mirror image (b < a).
template<class T>bool chmax(T &a,const T &b){
    if(!(a<b)) return false;
    a=b;
    return true;
}
template<class T>bool chmin(T &a,const T &b){
    if(!(b<a)) return false;
    a=b;
    return true;
}
//vector<int> bucket[MAX_N/1000];
//constexpr int MOD=1000000007;
// Modulus of this problem; it matches modint998244353 below.
// NOTE(review): MOD itself is not referenced in the code shown here.
constexpr int MOD=998244353;
#include<atcoder/all>
using namespace atcoder;
typedef __int128_t llint;
// All DP values are computed modulo 998244353.
using mint=modint998244353;

// Compressed values: rs[0] = 0, followed by the distinct input values in
// increasing order (built in main()). A "key" elsewhere is an index into rs.
vector<int> rs;

//https://github.com/NachiaVivias/cp-library
//https://judge.yosupo.jp/submission/77455

// Implicit-key splay tree: a sequence addressed by 0-based position, with
// range product (op), range lazy application of a 2x2 matrix (F), range
// reversal, and positional insert / erase. Every tree owns its own sentinel
// NIL (size 0, all links pointing to itself); the sentinel is never freed.
struct SplayTreeByIdx{
  // Element value: four sums, combined componentwise by op().
  struct S{mint dp1,dp2,dp1key,dp2key;};
  static S op(S l, S r){ return {l.dp1+r.dp1,l.dp2+r.dp2,l.dp1key+r.dp1key,l.dp2key+r.dp2key}; }
  static S e(){ return {0,0,0,0}; }
  // Lazy action: the 2x2 matrix [[mt00, mt01], [mt10, mt11]].
  struct F{mint mt00,mt01,mt10,mt11; };
  // Multiplies f into both column vectors (dp1, dp2) and (dp1key, dp2key).
  static S mapping(F f, S x){ return { f.mt00*x.dp1+f.mt01*x.dp2,f.mt10*x.dp1+f.mt11*x.dp2,f.mt00*x.dp1key+f.mt01*x.dp2key,f.mt10*x.dp1key+f.mt11*x.dp2key};}
  // Matrix product f * x, i.e. the action "apply x first, then f".
  static F composition(F f, F x){ return { f.mt00*x.mt00+f.mt01*x.mt10,f.mt00*x.mt01+f.mt01*x.mt11,f.mt10*x.mt00+f.mt11*x.mt10,f.mt10*x.mt01+f.mt11*x.mt11}; }
  static F id(){ return {1,0,0,1}; }
  struct Node{
    Node *l = 0, *r = 0, *p = 0;
    S a = e();               // value of this element (already includes actions applied to this node)
    S prod = e();            // op over the whole subtree (already includes this node's pending f)
    F f = id();              // action still to be pushed to the children
    bool propagated = true;  // false while f has not been pushed to the children
    int z = 0;               // set to 1 for real nodes; not read in this struct
    int sumz = 0;            // subtree size (0 for NIL)
    int rev = 0;             // children of this node still have to be swapped
  };
  Node *NIL = nullptr;
  Node *R;                   // current root (NIL when empty)
  // Pushes c's pending action and reversal flag down to its children.
  void prepareDown(Node* c){
    if(!c->propagated){
      if(c->l != NIL){
        c->l->a = mapping(c->f, c->l->a);
        c->l->prod = mapping(c->f, c->l->prod);
        c->l->f = composition(c->f, c->l->f);
        c->l->propagated = false;
      }
      if(c->r != NIL){
        c->r->a = mapping(c->f, c->r->a);
        c->r->prod = mapping(c->f, c->r->prod);
        c->r->f = composition(c->f, c->r->f);
        c->r->propagated = false;
      }
      c->f = id();
      c->propagated = true;
    }
    if(c->rev){
      swap(c->l, c->r);
      if(c->l != NIL) c->l->rev ^= 1;
      if(c->r != NIL) c->r->rev ^= 1;
      c->rev = 0;
    }
  }
  // Recomputes c's size and subtree product from its children.
  void prepareUp(Node* c){
    c->sumz = c->l->sumz + c->r->sumz + 1;
    c->prod = op(op(c->l->prod,c->a),c->r->prod);
  }
  SplayTreeByIdx(){
    if(!NIL){
      NIL = new Node();
      NIL->l = NIL->r = NIL->p = NIL;
      R = NIL;
    }
  }
  // The link (in the parent, or R for the root) that currently points at p.
  Node*& parentchild(Node* p){
    if(p->p == NIL) return R;
    if(p->p->l == p) return p->p->l;
    else return p->p->r;
  }
  void rotL(Node* c){
    Node* p = c->p;
    parentchild(p) = c;
    c->p = p->p;
    p->p = c;
    if(c->l != NIL) c->l->p = p;
    p->r = c->l;
    c->l = p;
  }
  void rotR(Node* c){
    Node* p = c->p;
    parentchild(p) = c;
    c->p = p->p;
    p->p = c;
    if(c->r != NIL) c->r->p = p;
    p->l = c->r;
    c->r = p;
  }
  // Moves c to the root, recomputing every node that passes below it.
  void splay(Node* c){
    while(c->p != NIL){
      Node* p = c->p;
      Node* pp = p->p;
      if(p->l == c){
        if(pp == NIL){ rotR(c); }
        else if(pp->l == p){ rotR(p); rotR(c); }
        else if(pp->r == p){ rotR(c); rotL(c); }
      }
      else{
        if(pp == NIL){ rotL(c); }
        else if(pp->r == p){ rotL(p); rotL(c); }
        else if(pp->l == p){ rotL(c); rotR(c); }
      }
      if(pp != NIL) prepareUp(pp);
      if(p != NIL) prepareUp(p);
    }
    prepareUp(c);
  }
  // Returns the node at position k, splayed to the root; NIL if k >= size.
  Node* kth_element(int k){
    if(k >= R->sumz) return NIL;
    Node* c = R;
    while(true){
      prepareDown(c);
      auto cl = c->l;
      if(cl->sumz == k) break;
      if(cl->sumz > k){ c = cl; continue; }
      k -= cl->sumz + 1;
      c = c->r;
    }
    prepareDown(c);
    splay(c);
    return c;
  }
  // Inserts x so that it ends up at position k (0 <= k <= size). The new node
  // becomes the root and is returned.
  Node* insert_at(int k, S x){
    Node* nx = new Node(*NIL);
    nx->z = nx->sumz = 1;
    nx->a = nx->prod = x;
    if(k == R->sumz){
      nx->l = R;
      if(R != NIL) R->p = nx;
      R = nx;
      prepareUp(nx);
      return nx;
    }
    auto p = kth_element(k);
    nx->l = p->l;
    nx->r = p;
    R = nx;
    if(p->l != NIL){
      prepareDown(p->l);
      p->l->p = nx;
    }
    p->p = nx;
    p->l = NIL;
    prepareUp(p);
    prepareUp(nx);
    return nx;
  }
  // Removes the element at position k and frees its node; no-op if k >= size.
  void erase_at(int k){
    if(k >= R->sumz) return;
    Node* toerase = NIL;
    if(k == 0){
      kth_element(0);
      prepareDown(R);
      toerase = R;
      R = R->r;
      if(R != NIL) R->p = NIL;
    }
    else{
      // Splay position k-1 to the root; position k is then the leftmost
      // node of the root's right subtree.
      kth_element(k-1);
      prepareDown(R);
      auto c = R->r;
      prepareDown(c);
      while(c->l != NIL){
        c = c->l;
        prepareDown(c);
      }
      auto p = c->p;
      toerase = c;
      parentchild(c) = c->r;
      if(c->r != NIL) c->r->p = p;
      splay(p);  // recomputes sizes/products of all former ancestors of c
    }
    // The node is unreachable now; free it so repeated point updates do not leak.
    delete toerase;
  }
  // Returns the root of a subtree holding exactly positions [l, r); NIL if
  // l >= r. Restructures the tree.
  Node* between(int l, int r){
    if(l >= r) return NIL;
    if(l == 0 && r == R->sumz) return R;
    if(l == 0) return kth_element(r)->l;
    if(r == R->sumz) return kth_element(l-1)->r;
    auto lp = kth_element(l-1);
    auto rp = kth_element(r);
    while(rp->l != lp){
      auto p = lp->p;
      prepareDown(p);
      prepareDown(lp);
      if(p->l == lp) rotR(lp);
      else rotL(lp);
      prepareUp(p);
      prepareUp(lp);
    }
    return lp->r;
  }
  // Reverses positions [l, r).
  void reverse(int l, int r){
    if(l >= r) return;
    auto c = between(l,r);
    c->rev ^= 1;
    prepareDown(c);
    splay(c);
  }
  // Applies f to every element in positions [l, r).
  void apply(int l, int r, F f){
    if(l >= r) return;
    auto c = between(l,r);
    c->a = mapping(f,c->a);
    c->prod = mapping(f,c->prod);
    c->f = composition(f,c->f);
    c->propagated = false;
    prepareDown(c);
    splay(c);
  }
  // op over positions [l, r); e() if l >= r.
  S prod(int l, int r){
    if(l >= r) return e();
    return between(l,r)->prod;
  }
  // Frees every node of the subtree rooted at tmp. Iterative, because a splay
  // tree may be a path of length n. The sentinel NIL is never freed.
  void del(Node* tmp){
      if(tmp == NIL) return;
      std::vector<Node*> stk{tmp};
      while(!stk.empty()){
          Node* t = stk.back();
          stk.pop_back();
          if(t->l != NIL) stk.push_back(t->l);
          if(t->r != NIL) stk.push_back(t->r);
          delete t;
      }
  }
  // Frees all elements. Safe on an empty tree; afterwards R == NIL and the
  // tree can be reused.
  void clear(){
      if(R == NIL) return;
      while(R->p!=NIL) R = R->p;
      del(R);
      R = NIL;
  }
};

// Binary trie over B-bit values of type U holding a multiset. It supports
// "XOR every stored value by x" lazily, order statistics (at / operator[]),
// rank queries (lower_bound / upper_bound return counts), min/max under an
// XOR bias, and multiplicity lookup.
template<typename U = unsigned, int B = 32>
class lazy_binary_trie {
    struct node {
        int cnt;      // number of stored values in this subtree (with multiplicity)
        U lazy;       // XOR mask still to be applied to this subtree
        node *ch[2];  // children for bit value 0 / 1 at this level
        node() : cnt(0), lazy(0), ch{ nullptr, nullptr } {}
    };
    // Pushes t's pending mask one level down; bit b of the mask swaps t's children.
    void push(node* t, int b) {
        if ((t->lazy >> (U)b) & (U)1) std::swap(t->ch[0], t->ch[1]);
        if (t->ch[0]) t->ch[0]->lazy ^= t->lazy;
        if (t->ch[1]) t->ch[1]->lazy ^= t->lazy;
        t->lazy = 0;
    }
    // Inserts one copy of val below t (creating nodes as needed); returns the subtree root.
    node* add(node* t, U val, int b = B - 1) {
        if (!t) t = new node;
        t->cnt += 1;
        if (b < 0) return t;
        push(t, b);
        bool f = (val >> (U)b) & (U)1;
        t->ch[f] = add(t->ch[f], val, b - 1);
        return t;
    }
    // Removes one copy of val, which must be present. When a node's count
    // reaches zero, the node and its whole remaining chain are freed.
    node* sub(node* t, U val, int b = B - 1) {
        assert(t);
        t->cnt -= 1;
        if (t->cnt == 0){
            del(t);  // the subtree held only this value; free all of it, not just t
            return nullptr;
        }
        if (b < 0) return t;
        push(t, b);
        bool f = (val >> (U)b) & (U)1;
        t->ch[f] = sub(t->ch[f], val, b - 1);
        return t;
    }
    // Stored value v minimising v ^ val (t must be non-null).
    U get_min(node* t, U val, int b = B - 1) {
        assert(t);
        if (b < 0) return 0;
        push(t, b);
        bool f = (val >> (U)b) & (U)1; f ^= !t->ch[f];
        return get_min(t->ch[f], val, b - 1) | ((U)f << (U)b);
    }
    // k-th smallest stored value (0-based).
    U get(node* t, int k, int b = B - 1) {
        if (b < 0) return 0;
        push(t, b);
        int m = t->ch[0] ? t->ch[0]->cnt : 0;
        return k < m ? get(t->ch[0], k, b - 1) : get(t->ch[1], k - m, b - 1) | ((U)1 << (U)b);
    }
    // Number of stored values strictly less than val.
    int count_lower(node* t, U val, int b = B - 1) {
        if (!t || b < 0) return 0;
        push(t, b);
        bool f = (val >> (U)b) & (U)1;
        return (f && t->ch[0] ? t->ch[0]->cnt : 0) + count_lower(t->ch[f], val, b - 1);
    }
    node *root;
public:
    lazy_binary_trie() : root(nullptr) {}
    int size() const {
        return root ? root->cnt : 0;
    }
    bool empty() const {
        return !root;
    }
    void insert(U val) {
        root = add(root, val);
    }
    void erase(U val) {
        root = sub(root, val);
    }
    void xor_all(U val) {
        if (root) root->lazy ^= val;
    }
    U max_element(U bias = 0) {
        return get_min(root, ~bias);
    }
    U min_element(U bias = 0) {
        return get_min(root, bias);
    }
    int lower_bound(U val) { // return id
        return count_lower(root, val);
    }
    int upper_bound(U val) { // return id
        return count_lower(root, val + 1);
    }
    U operator[](int k) {
        assert(0 <= k && k < size());
        return get(root, k);
    }
    U at(int k) {
        assert(0 <= k && k < size());
        return get(root, k);
    }
    // Multiplicity of val.
    int count(U val) {
        if (!root) return 0;
        node *t = root;
        for (int i = B - 1; i >= 0; i--) {
            push(t, i);
            t = t->ch[(val >> (U)i) & (U)1];
            if (!t) return 0;
        }
        return t->cnt;
    }
    // Frees the subtree rooted at t (null-safe; recursion depth is at most B + 1).
    void del(node *t){
        if(t==nullptr) return;
        if(t->ch[0]!=nullptr) del(t->ch[0]);
        if(t->ch[1]!=nullptr) del(t->ch[1]);
        delete t;
    }
    // Removes every value. Safe on an empty trie; the trie is reusable afterwards.
    void clear(){
        del(root);
        root=nullptr;
    }
};

// Storage for one DP table per vertex. Vertices reach their table through
// dp_pt (filled in main()), so merge() can swap which storage a vertex owns
// without copying any table.
SplayTreeByIdx dp[100005];
vector<SplayTreeByIdx*> dp_pt;

// Answer, accumulated by dfs().
mint ans=0;
int N;
vector<int> edge[100005];  // adjacency lists, 0-indexed vertices
int sub_size[100005];      // subtree sizes; initialised to 1 in main(), summed up in dfs()
mint pow2[100005];         // pow2[i] = 2^i (as mint), filled in main()
lint A[100005];            // input values
int pt[100005];            // NOTE(review): written in main() but not read in the code shown here
// Per-vertex sorted key sets (keys are indices into rs). Position i of
// *bt_pt[v] pairs with position i of *dp_pt[v].
// NOTE(review): sized 100000 while the arrays above use 100005; fine only while N <= 100000.
vector<lazy_binary_trie<int,17>> bt(100000);
vector<lazy_binary_trie<int,17>*> bt_pt(100000); 

bool merge(int left,int right){
    int rev=false;
    swap(dp_pt[left],dp_pt[right]),swap(bt_pt[left],bt_pt[right]),rev=true;
    vector<SplayTreeByIdx::S> adds;
    int N=bt_pt[right]->size();
    rep(i,N){ //key未満
        SplayTreeByIdx::S now=dp_pt[right]->prod(i,i+1);
        int key=bt_pt[right]->at(i);
        int hi=bt_pt[left]->lower_bound(key);
        SplayTreeByIdx::S lo_sum=dp_pt[left]->prod(0,hi);
        mint dp1=now.dp1*lo_sum.dp1;
        mint dp2=now.dp1*lo_sum.dp2+now.dp2*lo_sum.dp1;
        mint dp1_key=dp1*rs[key];
        mint dp2_key=dp2*rs[key];
        adds.push_back({dp1,dp2,dp1_key,dp2_key});
    }
    int before=0;
    mint sum[2];
    rep(i,N){
        SplayTreeByIdx::S now=dp_pt[right]->prod(i,i+1);
        int key=bt_pt[right]->at(i);
        int hi=bt_pt[left]->lower_bound(key);
        if(before!=hi){
            dp_pt[left]->apply(before,hi,{sum[0],0,sum[1],sum[0]});
        }
        sum[0]+=now.dp1,sum[1]+=now.dp2;
        before=hi;
    }
    if(before!=bt_pt[left]->size()){
        dp_pt[left]->apply(before,bt_pt[left]->size(),{sum[0],0,sum[1],sum[0]});
    }
    int idx=0;
    for(auto e:adds){
        int key=bt_pt[right]->at(idx);
        int hi=bt_pt[left]->lower_bound(key);
        if(bt_pt[left]->count(key)){
            SplayTreeByIdx::S res=dp_pt[left]->prod(hi,hi+1);
            res.dp1+=e.dp1,res.dp2+=e.dp2,res.dp1key+=e.dp1key,res.dp2key+=e.dp2key;
            dp_pt[left]->erase_at(hi);
            dp_pt[left]->insert_at(hi,res);
        }
        else{
            dp_pt[left]->insert_at(hi,e);
            bt_pt[left]->insert(key);
        }
        idx++;
    }
    return rev;
}

void dfs(int now,int par){
    for(auto e:edge[now]){
        if(e==par) continue;
        dfs(e,now);
        sub_size[now]+=sub_size[e];
        if(bt_pt[e]->count(0)){
            SplayTreeByIdx::S zero=dp_pt[e]->prod(0,1);
            dp_pt[e]->erase_at(0);
            zero.dp1+=pow2[sub_size[e]-1];
            dp_pt[e]->insert_at(0,zero);
        }
        else{
            dp_pt[e]->insert_at(0,{pow2[sub_size[e]-1],0,0,0});
            bt_pt[e]->insert(0);
        }
        merge(now,e);
        bt_pt[e]->clear();
        dp_pt[e]->clear();
    }
    ans+=dp_pt[now]->prod(0,bt_pt[now]->size()).dp2key*pow2[N-1-sub_size[now]+(now==0)];
}

int main(void){
    cin >> N;
    rep(i,N) cin >> A[i];
    vector<int> as;
    rep(i,N) as.push_back(A[i]);
    sort(all(as));
    rs.push_back(0);
    int pre=-1;
    rep(i,N) if(pre!=as[i]) rs.push_back(as[i]),pre=as[i];
    rep(i,N) sub_size[i]=1;
    pow2[0]=1;
    rep(i,100004) pow2[i+1]=pow2[i]*2;
    rep(i,100004) pt[i]=i;
    rep(i,N) bt_pt[i]=&bt[i];
    rep(i,N){
        dp[i].insert_at(0,{1,1,A[i],A[i]});
        bt_pt[i]->insert(lower_bound(all(rs),A[i])-rs.begin());
        dp_pt.push_back(&dp[i]);
    }
    rep(i,N-1){
        int u,v;
        cin >> u >> v;
        u--,v--;
        edge[u].push_back(v);
        edge[v].push_back(u);
    }
    dfs(0,-1);
    cout << ans.val() << endl;
}
0