結果
| 問題 | No.2116 Making Forest Hard |
| コンテスト | |
| ユーザー | bayashiko |
| 提出日時 | 2022-10-10 05:30:00 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | MLE |
| 実行時間 | - |
| コード長 | 13,816 bytes |
| コンパイル時間 | 6,789 ms |
| コンパイル使用メモリ | 305,852 KB |
| 最終ジャッジ日時 | 2025-02-08 01:07:42 |
| ジャッジサーバーID (参考情報) | judge3 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | MLE * 2 |
| other | AC * 19 TLE * 21 MLE * 13 |
コンパイルメッセージ
In member function ‘RedBlackTree<std::vector<atcoder::static_modint<998244353> >, std::vector<atcoder::static_modint<998244353> >, dd, dl, ll, none>::Node& RedBlackTree<std::vector<atcoder::static_modint<998244353> >, std::vector<atcoder::static_modint<998244353> >, dd, dl, ll, none>::Node::operator=(RedBlackTree<std::vector<atcoder::static_modint<998244353> >, std::vector<atcoder::static_modint<998244353> >, dd, dl, ll, none>::Node&&)’,
inlined from ‘RedBlackTree<D, L, f, g, h, p>::Node* RedBlackTree<D, L, f, g, h, p>::alloc(Node*, Node*) [with D = std::vector<atcoder::static_modint<998244353> >; L = std::vector<atcoder::static_modint<998244353> >; D (* f)(D, D) = dd; D (* g)(D, L) = dl; L (* h)(L, L) = ll; L (* p)(L, int) = none]’ at main.cpp:102:30,
inlined from ‘RedBlackTree<D, L, f, g, h, p>::Node* RedBlackTree<D, L, f, g, h, p>::submerge(Node*, Node*) [with D = std::vector<atcoder::static_modint<998244353> >; L = std::vector<atcoder::static_modint<998244353> >; D (* f)(D, D) = dd; D (* g)(D, L) = dl; L (* h)(L, L) = ll; L (* p)(L, int) = none]’ at main.cpp:180:17:
main.cpp:71:10: warning: ‘*(__vector(2) int*)((char*)&<unnamed> + offsetof(RedBlackTree<std::vector<atcoder::static_modint<998244353, 0>, std::allocator<atcoder::static_modint<998244353, 0> > >, std::vector<atcoder::static_modint<998244353, 0>, std::allocator<atcoder::static_modint<998244353, 0> > >, &dd(std::vector<atcoder::static_modint<998244353, 0>, std::allocator<atcoder::static_modint<998244353, 0> > >, std::vector<atcoder::static_modint<998244353, 0>, std::allocator<atcoder::static_modint<998244353, 0> > >), &dl(std::vector<atcoder::static_modint<998244353, 0>, std::allocator<atcoder::static_modint<998244353, 0> > >, std::vector<atcoder::static_modint<998244353, 0>, std::allocator<atcoder::static_modint<998244353, 0> > >), &ll(std::vector<atcoder::static_modint<998244353, 0>, std::allocator<atcoder::static_modint<998244353, 0> > >, std::vector<atcoder::static_modint<99824
ソースコード
#if defined(LOCAL)
#include<stdc++.h>
#else
#include<bits/stdc++.h>
#endif
#include<random>
#pragma GCC optimize("Ofast")
//#pragma GCC target("avx2")
#pragma GCC optimize("unroll-loops")
using namespace std;
//#include<boost/multiprecision/cpp_int.hpp>
//#include<boost/multiprecision/cpp_dec_float.hpp>
//namespace mp=boost::multiprecision;
//#define mulint mp::cpp_int
//#define mulfloat mp::cpp_dec_float_100
// Global object whose constructor runs before main(): unties cin from cout, disables
// synchronisation with C stdio, and makes cout print floating-point values in fixed
// notation with 15 digits after the decimal point.
// NOTE(review): identifiers starting with "__" are reserved for the implementation;
// consider renaming both the type and the variable.
struct __INIT{
    __INIT(){
        cin.tie(nullptr);
        ios::sync_with_stdio(false);
        cout<<fixed<<setprecision(15);
    }
} __init;
//#define INF (1<<30)
// Large 64-bit sentinels: LINF = 2^56, MINF = 2e18.
#define LINF (lint)(1LL<<56)
#define MINF (lint)(2e18)
// Replaces every later `endl` token with "\n", so line breaks do not flush the stream.
#define endl "\n"
// rep(i,n): i = 0, 1, ..., n-1; reprev(i,n): i = n-1, ..., 0. Both declare i as lint.
#define rep(i,n) for(lint (i)=0;(i)<(n);(i)++)
#define reprev(i,n) for(lint (i)=(n-1);(i)>=0;(i)--)
// Number of set bits of x, computed as unsigned long long.
#define flc(x) __builtin_popcountll(x)
#define pint pair<int,int>
#define pdouble pair<double,double>
#define plint pair<lint,lint>
#define fi first
#define se second
// Begin/end iterator pair of a container, for passing to algorithms.
#define all(x) x.begin(),x.end()
//#define vec vector<lint>
// Advances x to its next lexicographic permutation; yields false after the last one.
#define nep(x) next_permutation(all(x))
typedef long long lint;
// Offsets to the 8 grid neighbours; the even indices are the 4 axis directions.
int dx[8]={1,1,0,-1,-1,-1,0,1};
int dy[8]={0,1,1,1,0,-1,-1,-1};
// NOTE(review): MAX_N is not referenced by the code in this file.
const int MAX_N=3e5+5;
// chmax: when b is greater than a, store b into a. Returns whether a changed.
template<class T>bool chmax(T &a,const T &b){
    const bool improves=a<b;
    if(improves) a=b;
    return improves;
}
// chmin: when b is less than a, store b into a. Returns whether a changed.
template<class T>bool chmin(T &a,const T &b){
    const bool improves=b<a;
    if(improves) a=b;
    return improves;
}
//vector<int> bucket[MAX_N/1000];
//constexpr int MOD=1000000007;
// Problem modulus. NOTE(review): MOD itself is not referenced below; the arithmetic uses
// atcoder's modint998244353 (same prime) instead.
constexpr int MOD=998244353;
// AtCoder Library: provides modint998244353, used for all modular arithmetic below.
#include<atcoder/all>
using namespace atcoder;
// 128-bit integer alias. NOTE(review): not referenced in this file.
typedef __int128_t llint;
//https://ei1333.github.io/luzhiled/snippets/structure/red-black-tree.html
/**
 * Fixed-capacity object pool.
 * `pool` owns sz objects. `stock[0, ptr)` lists the objects that are currently free.
 * clear() must be called before the first alloc() to fill the free list.
 */
template< class T >
struct ArrayPool {
  vector< T > pool;
  vector< T * > stock;
  int ptr = 0;  // number of free entries in stock; 0 (nothing free) until clear() runs
  ArrayPool(int sz) : pool(sz), stock(sz) {}
  // Hands out a free object. Exhaustion used to read stock[-1]; now it trips the assert.
  inline T *alloc() {
    assert(ptr > 0 && "ArrayPool exhausted");
    return stock[--ptr];
  }
  // Returns t, which must have been obtained from alloc(), to the free list.
  inline void free(T *t) {
    assert(ptr < (int) stock.size() && "ArrayPool::free called more often than alloc");
    stock[ptr++] = t;
  }
  // Marks every pooled object as free. Previously handed-out pointers must no longer be used.
  void clear() {
    ptr = (int) pool.size();
    for(size_t i = 0; i < pool.size(); i++) stock[i] = &pool[i];
  }
};
/**
 * Sequence container built on a red-black tree (adapted from ei1333's red-black-tree snippet).
 * Elements live in the leaves, in order. Every internal node has two children and caches
 * the aggregate of its subtree. Range operators are applied lazily.
 *
 * Template parameters:
 *   D, L  element/aggregate type and lazy-operator type
 *   f     combines two aggregates (M1 is used as its identity)
 *   g     applies an operator to an element or aggregate
 *   h     composes two operators. propagate()/set_propagate() call it as
 *         h(pending, newly_applied), but dump() calls it as h(ancestor, descendant).
 *   p     adapts an operator to a segment with the given number of leaves
 * NOTE(review): the two h() call orders agree only when h is commutative on the operators
 * actually used. Check this before instantiating with non-commuting operators.
 *
 * Nodes come from an ArrayPool. This class does not check for pool exhaustion itself.
 */
template< class D, class L, D (*f)(D, D), D (*g)(D, L), L (*h)(L, L), L (*p)(L, int) >
struct RedBlackTree {
  enum COLOR {
    BLACK, RED
  };
  struct Node {
    // Every member has a default so that pooled and temporary nodes are never read
    // uninitialised (this silences the -Wmaybe-uninitialized warning on Node::operator=).
    Node *l = nullptr, *r = nullptr;
    COLOR color = BLACK;
    int level = 0;  // balancing rank: black nodes on the left spine below (0 for leaves)
    int cnt = 0;    // number of leaves in this subtree
    D key{};        // element value for a leaf; M1 for an internal node
    D sum{};        // aggregate of the subtree
    L lazy{};       // operator not yet pushed to the children (or applied to key, for a leaf)
    Node() = default;
    // Leaf holding a single element. Initialisers follow declaration order.
    Node(const D &k, const L &laz) :
      color(BLACK), level(0), cnt(1), key(k), sum(k), lazy(laz) {}
    // Internal node joining l and r. cnt, level and sum are filled in by update().
    Node(Node *l, Node *r, const D &k, const L &laz) :
      l(l), r(r), color(RED), key(k), lazy(laz) {}
  };
  ArrayPool< Node > pool;
  const D M1;   // identity of f
  const L OM0;  // identity operator, meaning "nothing pending"
  RedBlackTree(int sz, const D &M1, const L &OM0) :
    pool(sz), M1(M1), OM0(OM0) { pool.clear(); }
  inline Node *alloc(const D &key) {
    return &(*pool.alloc() = Node(key, OM0));
  }
  inline Node *alloc(Node *l, Node *r) {
    auto t = &(*pool.alloc() = Node(l, r, M1, OM0));
    return update(t);
  }
  // Returns its argument: nodes are modified in place.
  // (Presumably an override point for a persistent variant.)
  virtual Node *clone(Node *t) { return t; }
  inline int count(const Node *t) { return t ? t->cnt : 0; }
  inline D sum(const Node *t) { return t ? t->sum : M1; }
  // Recomputes cnt, level and sum of t from its children.
  Node *update(Node *t) {
    t->cnt = count(t->l) + count(t->r) + (!t->l || !t->r);
    t->level = t->l ? t->l->level + (t->l->color == BLACK) : 0;
    t->sum = f(f(sum(t->l), t->key), sum(t->r));
    return t;
  }
  // Pushes t's pending operator into its key (leaf) or into both children (internal node).
  Node *propagate(Node *t) {
    t = clone(t);
    if(t->lazy != OM0) {
      if(!t->l) {
        t->key = g(t->key, p(t->lazy, 1));
      } else {
        t->l = clone(t->l);
        t->l->lazy = h(t->l->lazy, t->lazy);
        t->l->sum = g(t->l->sum, p(t->lazy, count(t->l)));
        if(t->r) {
          t->r = clone(t->r);
          t->r->lazy = h(t->r->lazy, t->lazy);
          t->r->sum = g(t->r->sum, p(t->lazy, count(t->r)));
        }
      }
      t->lazy = OM0;
    }
    return update(t);
  }
  // Single rotation: b == true lifts the left child, otherwise the right child.
  Node *rotate(Node *t, bool b) {
    t = propagate(t);
    Node *s;
    if(b) {
      s = propagate(t->l);
      t->l = s->r;
      s->r = t;
    } else {
      s = propagate(t->r);
      t->r = s->l;
      s->l = t;
    }
    update(t);
    return update(s);
  }
  // Joins two non-empty trees, descending along the one with the larger rank.
  Node *submerge(Node *l, Node *r) {
    if(l->level < r->level) {
      r = propagate(r);
      Node *c = (r->l = submerge(l, r->l));
      if(r->color == BLACK && c->color == RED && c->l && c->l->color == RED) {
        r->color = RED;
        c->color = BLACK;
        if(r->r->color == BLACK) return rotate(r, true);
        r->r->color = BLACK;
      }
      return update(r);
    }
    if(l->level > r->level) {
      l = propagate(l);
      Node *c = (l->r = submerge(l->r, r));
      if(l->color == BLACK && c->color == RED && c->r && c->r->color == RED) {
        l->color = RED;
        c->color = BLACK;
        if(l->l->color == BLACK) return rotate(l, false);
        l->l->color = BLACK;
      }
      return update(l);
    }
    return alloc(l, r);
  }
  // Concatenates two sequences. Either side may be empty (nullptr).
  Node *merge(Node *l, Node *r) {
    if(!l || !r) return l ? l : r;
    Node *c = submerge(l, r);
    c->color = BLACK;
    return c;
  }
  // Splits t into its first k leaves and the rest. Internal nodes on the split path are
  // returned to the pool. k <= 0 yields {nullptr, t}; the original recursed with a negative k
  // and could free a leaf.
  pair< Node *, Node * > split(Node *t, int k) {
    if(!t) return {nullptr, nullptr};
    t = propagate(t);
    if(k <= 0) return {nullptr, t};
    if(k >= count(t)) return {t, nullptr};
    Node *l = t->l, *r = t->r;
    pool.free(t);
    if(k < count(l)) {
      auto pp = split(l, k);
      return {pp.first, merge(pp.second, r)};
    }
    if(k > count(l)) {
      auto pp = split(r, k - count(l));
      return {merge(l, pp.first), pp.second};
    }
    return {l, r};
  }
  Node *build(int l, int r, const vector< D > &v) {
    if(l + 1 >= r) return alloc(v[l]);
    return merge(build(l, (l + r) >> 1, v), build((l + r) >> 1, r, v));
  }
  // Builds a tree holding v in order. An empty v yields an empty tree.
  Node *build(const vector< D > &v) {
    if(v.empty()) return nullptr;
    return build(0, (int) v.size(), v);
  }
  // Writes the leaves of r, left to right, with all pending operators applied.
  void dump(Node *r, typename vector< D >::iterator &it, L lazy) {
    if(r->lazy != OM0) lazy = h(lazy, r->lazy);
    if(!r->l || !r->r) {
      *it++ = g(r->key, lazy);
      return;
    }
    dump(r->l, it, lazy);
    dump(r->r, it, lazy);
  }
  vector< D > dump(Node *r) {
    vector< D > v((size_t) count(r));
    if(!r) return v;
    auto it = begin(v);
    dump(r, it, OM0);
    return v;
  }
  // Only instantiable when std::to_string accepts D.
  string to_string(Node *r) {
    auto s = dump(r);
    string ret;
    for(size_t i = 0; i < s.size(); i++) {
      ret += std::to_string(s[i]);
      ret += ", ";
    }
    return (ret);
  }
  // Inserts v so that it becomes element k.
  void insert(Node *&t, int k, const D &v) {
    auto x = split(t, k);
    t = merge(merge(x.first, alloc(v)), x.second);
  }
  // Removes element k and returns its value.
  D erase(Node *&t, int k) {
    auto x = split(t, k);
    auto y = split(x.second, 1);
    auto v = y.first->key;
    pool.free(y.first);
    t = merge(x.first, y.second);
    return v;
  }
  // Aggregate of elements [a, b). An empty or reversed range yields M1 without touching t.
  D query(Node *&t, int a, int b) {
    if(a >= b) return M1;
    auto x = split(t, a);
    auto y = split(x.second, b - a);
    auto ret = sum(y.first);
    t = merge(x.first, merge(y.first, y.second));
    return ret;
  }
  // Applies operator pp to elements [a, b). An empty range is a no-op
  // (the original dereferenced nullptr).
  void set_propagate(Node *&t, int a, int b, const L &pp) {
    if(a >= b) return;
    auto x = split(t, a);
    auto y = split(x.second, b - a);
    y.first->lazy = h(y.first->lazy, pp);
    t = merge(x.first, merge(propagate(y.first), y.second));
  }
  // Overwrites element k with x. Expects a non-empty t and 0 <= k < size(t).
  void set_element(Node *&t, int k, const D &x) {
    // Propagate before the leaf test: a leaf can still hold an operator pushed down by its
    // parent. Keeping it would later apply that stale operator to the new value x.
    t = propagate(t);
    if(!t->l) {
      t->key = t->sum = x;
      return;
    }
    if(k < count(t->l)) set_element(t->l, k, x);
    else set_element(t->r, k - count(t->l), x);
    t = update(t);
  }
  int size(Node *t) {
    return count(t);
  }
  bool empty(Node *t) {
    return !t;
  }
  // An empty sequence is represented by nullptr.
  Node *makeset() {
    return (nullptr);
  }
};
using mint = modint998244353;
// Aggregate combiner f(D, D) for the tree: component-wise sum of the five entries.
inline vector<mint> dd(vector<mint> a,vector<mint> b){
    vector<mint> total(5);
    for(int k=0;k<5;++k) total[k]=a[k]+b[k];
    return total;
}
// Applies operator b (2x2 matrix [[b0,b1],[b2,b3]], row-major) to element/aggregate a.
// The matrix maps the pair (a0,a1) and the pair (a2,a3); entry a[4] (the key) is kept as is.
inline vector<mint> dl(vector<mint> a,vector<mint> b){
    const mint m00=b[0],m01=b[1],m10=b[2],m11=b[3];
    vector<mint> out(5);
    for(int base=0;base<4;base+=2){
        out[base]  =a[base]*m00+a[base+1]*m01;
        out[base+1]=a[base]*m10+a[base+1]*m11;
    }
    out[4]=a[4];
    return out;
}
// Operator composition h(L, L): row-major 2x2 matrix product a*b.
// NOTE(review): dl applies a matrix as M*x, so "a then b" would strictly be b*a. The
// operators built in merge() all have the form [[s0,0],[s1,s0]] and commute, so the
// order does not matter there.
inline vector<mint> ll(vector<mint> a,vector<mint> b){
    vector<mint> prod(4);
    for(int row=0;row<2;++row){
        for(int col=0;col<2;++col){
            prod[2*row+col]=a[2*row]*b[col]+a[2*row+1]*b[2+col];
        }
    }
    return prod;
}
// Length hook p(L, int) for the tree. These operators do not depend on segment length,
// so the operator is returned unchanged.
inline vector<mint> none(vector<mint> a,int b){
    (void)b;  // segment length is irrelevant here
    return a;
}
/**
 * Multiset of B-bit keys stored in a binary trie (most significant bit first), with a lazy
 * "xor every stored key with x" operation.
 * Nodes are allocated with new. erase() frees nodes whose count drops to zero, and clear()
 * frees everything.
 * NOTE(review): there is no destructor, and a copy shares nodes (root is copied shallowly).
 */
template<typename U = unsigned, int B = 32>
class lazy_binary_trie {
  struct node {
    int cnt;      // number of keys stored in this subtree
    U lazy;       // xor mask not yet pushed to the children
    node *ch[2];
    node() : cnt(0), lazy(0), ch{ nullptr, nullptr } {}
  };
  // Deletes every node of the subtree rooted at t.
  static void release(node* t) {
    if (!t) return;
    release(t->ch[0]);
    release(t->ch[1]);
    delete t;
  }
  // Applies t's pending xor at bit b (swapping the children) and hands the mask down.
  void push(node* t, int b) {
    if ((t->lazy >> (U)b) & (U)1) swap(t->ch[0], t->ch[1]);
    if (t->ch[0]) t->ch[0]->lazy ^= t->lazy;
    if (t->ch[1]) t->ch[1]->lazy ^= t->lazy;
    t->lazy = 0;
  }
  node* add(node* t, U val, int b = B - 1) {
    if (!t) t = new node;
    t->cnt += 1;
    if (b < 0) return t;
    push(t, b);
    bool f = (val >> (U)b) & (U)1;
    t->ch[f] = add(t->ch[f], val, b - 1);
    return t;
  }
  node* sub(node* t, U val, int b = B - 1) {
    assert(t);
    t->cnt -= 1;
    if (t->cnt == 0) {
      // This subtree held only the erased key. Free it instead of leaking it.
      release(t);
      return nullptr;
    }
    if (b < 0) return t;
    push(t, b);
    bool f = (val >> (U)b) & (U)1;
    t->ch[f] = sub(t->ch[f], val, b - 1);
    return t;
  }
  // Stored key minimising (key ^ val).
  U get_min(node* t, U val, int b = B - 1) {
    assert(t);
    if (b < 0) return 0;
    push(t, b);
    bool f = (val >> (U)b) & (U)1; f ^= !t->ch[f];
    return get_min(t->ch[f], val, b - 1) | ((U)f << (U)b);
  }
  // k-th smallest stored key (0-indexed).
  U get(node* t, int k, int b = B - 1) {
    if (b < 0) return 0;
    push(t, b);
    int m = t->ch[0] ? t->ch[0]->cnt : 0;
    return k < m ? get(t->ch[0], k, b - 1) : get(t->ch[1], k - m, b - 1) | ((U)1 << (U)b);
  }
  // Number of stored keys strictly less than val.
  int count_lower(node* t, U val, int b = B - 1) {
    if (!t || b < 0) return 0;
    push(t, b);
    bool f = (val >> (U)b) & (U)1;
    return (f && t->ch[0] ? t->ch[0]->cnt : 0) + count_lower(t->ch[f], val, b - 1);
  }
  node *root;
public:
  lazy_binary_trie() : root(nullptr) {}
  int size() const {
    return root ? root->cnt : 0;
  }
  bool empty() const {
    return !root;
  }
  // Removes every key and frees all nodes.
  void clear() {
    release(root);
    root = nullptr;
  }
  void insert(U val) {
    root = add(root, val);
  }
  // Removes one occurrence of val, which must be present.
  void erase(U val) {
    root = sub(root, val);
  }
  void xor_all(U val) {
    if (root) root->lazy ^= val;
  }
  U max_element(U bias = 0) {
    return get_min(root, ~bias);
  }
  U min_element(U bias = 0) {
    return get_min(root, bias);
  }
  int lower_bound(U val) { // return id
    return count_lower(root, val);
  }
  int upper_bound(U val) { // return id
    return count_lower(root, val + 1);
  }
  U operator[](int k) {
    assert(0 <= k && k < size());
    return get(root, k);
  }
  // Multiplicity of val.
  int count(U val) {
    if (!root) return 0;
    node *t = root;
    for (int i = B - 1; i >= 0; i--) {
      push(t, i);
      t = t->ch[(val >> (U)i) & (U)1];
      if (!t) return 0;
    }
    return t->cnt;
  }
};
// Sequence tree over 5-entry elements {dp1, dp2, dp1*key, dp2*key, key}.
// Its operators are 2x2 matrices stored row-major; {1,0,0,1} is the identity.
using T = RedBlackTree<vector<mint>,vector<mint>,dd,dl,ll,none>;
// Single node pool (2,000,000 nodes) shared by every per-vertex tree.
// NOTE(review): each Node holds three std::vector<mint>, so this preallocation is a large
// fixed memory cost; likely related to the MLE verdict, confirm.
T RBT(2000000,vector<mint>{0,0,0,0,0},vector<mint>{1,0,0,1});
// t[v]: root of vertex v's tree, with one leaf per distinct key, ordered by key.
vector<T::Node*> t;
mint ans=0;
int N;
// Adjacency lists of the input tree (0-indexed vertices).
vector<int> edge[100005];
// sub_size[v]: number of vertices in v's subtree (filled in by dfs).
int sub_size[100005];
// pow2[i] = 2^i as a mint.
mint pow2[100005];
lint A[100005];
// pt[v]: index of the container representing v (main sets pt[v] = v; nothing changes it).
int pt[100005];
// bt[v]: trie of the distinct keys in t[v]. bt_pt lets merge() swap containers by pointer.
vector<lazy_binary_trie<int,30>> bt(100000);
vector<lazy_binary_trie<int,30>*> bt_pt(100000);
bool merge(int left,int right){
int rev=false;
if(RBT.size(t[left])<RBT.size(t[right])) swap(t[left],t[right]),swap(bt_pt[left],bt_pt[right]),rev=true; //マージテク
vector<vector<mint>> adds;
int N=RBT.size(t[right]);
rep(i,N){ //key未満
vector<mint> now=RBT.query(t[right],i,i+1);
int key=now[4].val();
int hi=bt_pt[left]->lower_bound(key);
vector<mint> lo_sum=RBT.query(t[left],0,hi);
mint dp1=now[0]*lo_sum[0];
mint dp2=now[0]*lo_sum[1]+now[1]*lo_sum[0];
mint dp1_key=dp1*key;
mint dp2_key=dp2*key;
adds.push_back({dp1,dp2,dp1_key,dp2_key,key});
}
int before=0;
mint sum[2];
rep(i,N){
vector<mint> now=RBT.query(t[right],i,i+1);
int key=now[4].val();
int hi=bt_pt[left]->lower_bound(key);
if(before!=hi) RBT.set_propagate(t[left],before,hi,{sum[0],0,sum[1],sum[0]});
sum[0]+=now[0],sum[1]+=now[1];
before=hi;
}
if(before!=RBT.size(t[left])) RBT.set_propagate(t[left],before,RBT.size(t[left]),{sum[0],0,sum[1],sum[0]});
for(auto e:adds){
int key=e[4].val();
int hi=bt_pt[left]->lower_bound(key);
if(bt_pt[left]->count(e[4].val())){
vector<mint> res=RBT.query(t[left],hi,hi+1);
rep(i,4) res[i]+=e[i];
RBT.set_element(t[left],hi,res);
}
else{
RBT.insert(t[left],hi,e);
bt_pt[left]->insert(e[4].val());
}
}
return rev;
}
void dfs(int now,int par){
for(auto e:edge[now]){
if(e==par) continue;
dfs(e,now);
sub_size[now]+=sub_size[e];
if(bt_pt[e]->count(0)){
vector<mint> zero=RBT.query(t[e],0,1);
zero[0]+=pow2[sub_size[e]-1];
RBT.set_element(t[e],0,zero);
}
else{
RBT.insert(t[e],0,{pow2[sub_size[e]-1],0,0,0,0});
bt_pt[e]->insert(0);
}
merge(now,e);
}
ans+=RBT.query(t[pt[now]],0,RBT.size(t[pt[now]]))[3]*pow2[N-1-sub_size[now]+(now==0)];
}
int main(void){
cin >> N;
rep(i,N) cin >> A[i];
rep(i,N) sub_size[i]=1;
pow2[0]=1;
rep(i,100004) pow2[i+1]=pow2[i]*2;
rep(i,100004) pt[i]=i;
rep(i,N) bt_pt[i]=&bt[i];
rep(i,N){
T::Node* tt=RBT.makeset();
RBT.insert(tt,0,{1,1,A[i],A[i],A[i]});
t.push_back(tt);
bt_pt[i]->insert(A[i]);
}
rep(i,N-1){
int u,v;
cin >> u >> v;
u--,v--;
edge[u].push_back(v);
edge[v].push_back(u);
}
dfs(0,-1);
int sum1=0,sum2=0;
rep(i,N){
sum1+=RBT.size(t[i]);
sum2+=bt_pt[i]->size();
}
cout << ans.val() << endl;
}
bayashiko