結果

問題 No.2115 Making Forest Easy
ユーザー noshi91
提出日時 2022-10-28 23:10:55
言語 C++17
(gcc 13.2.0 + boost 1.83.0)
結果
AC  
実行時間 14 ms / 2,000 ms
コード長 5,554 bytes
コンパイル時間 2,534 ms
コンパイル使用メモリ 152,816 KB
実行使用メモリ 4,384 KB
最終ジャッジ日時 2023-09-20 06:24:35
合計ジャッジ時間 5,072 ms
ジャッジサーバーID
(参考情報)
judge14 / judge15
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
4,376 KB
testcase_01 AC 1 ms
4,380 KB
testcase_02 AC 12 ms
4,376 KB
testcase_03 AC 14 ms
4,376 KB
testcase_04 AC 13 ms
4,380 KB
testcase_05 AC 10 ms
4,380 KB
testcase_06 AC 13 ms
4,376 KB
testcase_07 AC 12 ms
4,380 KB
testcase_08 AC 12 ms
4,380 KB
testcase_09 AC 12 ms
4,376 KB
testcase_10 AC 12 ms
4,380 KB
testcase_11 AC 13 ms
4,376 KB
testcase_12 AC 12 ms
4,376 KB
testcase_13 AC 11 ms
4,376 KB
testcase_14 AC 12 ms
4,376 KB
testcase_15 AC 12 ms
4,376 KB
testcase_16 AC 13 ms
4,380 KB
testcase_17 AC 3 ms
4,376 KB
testcase_18 AC 3 ms
4,376 KB
testcase_19 AC 7 ms
4,384 KB
testcase_20 AC 10 ms
4,380 KB
testcase_21 AC 3 ms
4,380 KB
testcase_22 AC 4 ms
4,376 KB
testcase_23 AC 10 ms
4,380 KB
testcase_24 AC 13 ms
4,376 KB
testcase_25 AC 7 ms
4,380 KB
testcase_26 AC 12 ms
4,380 KB
testcase_27 AC 8 ms
4,380 KB
testcase_28 AC 12 ms
4,380 KB
testcase_29 AC 3 ms
4,380 KB
testcase_30 AC 12 ms
4,376 KB
testcase_31 AC 12 ms
4,380 KB
testcase_32 AC 12 ms
4,376 KB
testcase_33 AC 12 ms
4,376 KB
testcase_34 AC 2 ms
4,380 KB
testcase_35 AC 4 ms
4,376 KB
testcase_36 AC 13 ms
4,376 KB
testcase_37 AC 7 ms
4,380 KB
testcase_38 AC 4 ms
4,380 KB
testcase_39 AC 6 ms
4,376 KB
testcase_40 AC 12 ms
4,376 KB
testcase_41 AC 12 ms
4,376 KB
testcase_42 AC 12 ms
4,376 KB
testcase_43 AC 7 ms
4,380 KB
testcase_44 AC 4 ms
4,376 KB
testcase_45 AC 10 ms
4,380 KB
testcase_46 AC 13 ms
4,376 KB
testcase_47 AC 6 ms
4,376 KB
testcase_48 AC 13 ms
4,380 KB
testcase_49 AC 12 ms
4,380 KB
testcase_50 AC 8 ms
4,376 KB
testcase_51 AC 6 ms
4,380 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

// たすけてくれ

/// Residue modulo the compile-time constant `m`, kept canonically in [0, m).
/// Division multiplies by r^(m-2) (Fermat's little theorem), so it is only a
/// true inverse when `m` is prime; dividing by 0 silently yields 0.
template <unsigned m> struct modint {
  unsigned v;

  /// Reduces an arbitrary signed 64-bit value into [0, m).
  modint(long long x = 0) : v() {
    long long rem = x % m;
    if (rem < 0)
      rem += m;
    v = static_cast<unsigned>(rem);
  }

  modint &operator+=(modint r) {
    const unsigned s = v + r.v;
    v = (s >= m) ? s - m : s;
    return *this;
  }
  modint &operator-=(modint r) {
    // Borrow m first when the subtraction would go negative.
    v = (v < r.v) ? v + m - r.v : v - r.v;
    return *this;
  }
  modint &operator*=(modint r) {
    // Widen to 64 bits so the product cannot overflow before reduction.
    v = static_cast<unsigned>(static_cast<unsigned long long>(v) * r.v % m);
    return *this;
  }
  modint &operator/=(modint r) {
    // Multiply by r^(m-2) via square-and-multiply.
    for (unsigned e = m - 2; e != 0; e >>= 1) {
      if (e & 1u)
        *this *= r;
      r *= r;
    }
    return *this;
  }

  modint operator+(modint r) const {
    modint t = *this;
    t += r;
    return t;
  }
  modint operator-(modint r) const {
    modint t = *this;
    t -= r;
    return t;
  }
  modint operator*(modint r) const {
    modint t = *this;
    t *= r;
    return t;
  }
  modint operator/(modint r) const {
    modint t = *this;
    t /= r;
    return t;
  }
};

/// Arithmetic modulo the prime 998244353.
using mint = modint<998244353>;

struct qmi {
  mint c0, c1;
  qmi(mint c0_, mint c1_ = 0) : c0(c0_), c1(c1_) {}
  friend qmi operator+(const qmi &x, const qmi &y) {
    return {x.c0 + y.c0, x.c1 + y.c1};
  }
  friend qmi operator*(const qmi &x, const qmi &y) {
    return {x.c0 * y.c0, x.c1 * y.c0 + x.c0 * y.c1};
  }
  friend qmi operator*(const mint x, const qmi &y) {
    return {x * y.c0, x * y.c1};
  }

  qmi &operator*=(const qmi &y) { return *this = *this * y; }
  qmi &operator+=(const qmi &y) {
    c0 += y.c0;
    c1 += y.c1;
    return *this;
  }
};

#include <algorithm>
#include <cassert>
#include <iostream>
#include <tuple>
#include <utility>
#include <vector>

// Loop `i` over the half-open integer range [a, b); `b` is converted to int.
#define rep(i, a, b) for (int i = (a); i < int(b); ++i)

#include <random>

/// Treap priority type: the output type of the 64-bit Mersenne Twister.
using pri_t = std::mt19937_64::result_type;
/// Fixed-seed priority generator, so every run builds the same treap shapes.
std::mt19937_64 rng(0xE19E71EBC6129513ULL);

#include <memory>

struct node_t;
/// Owning pointer to a treap node; a null ptr_t is the empty treap.
using ptr_t = std::unique_ptr<node_t>;

/// Treap node ordered by key `A` and heap-ordered by the random `pri`.
///  - `val`:  this node's own weight
///  - `sum`:  total weight of the subtree (maintained by fix())
///  - `psum`: subtree total of A * weight (maintained by fix())
///  - `mul`:  lazy multiplier not yet applied to val/sum/psum or handed to
///            the children (push() applies it and resets it to (1, 0))
struct node_t {
  ptr_t l;
  ptr_t r;
  pri_t pri;
  qmi val;
  qmi sum;
  qmi psum;
  qmi mul{1, 0};
  int A;

  node_t(int A_, qmi val_)
      : pri(rng()), val(val_), sum(val_), psum(mint(A_) * val_), A(A_) {}
};

/// Total subtree weight of `p` with its pending multiplier applied; 0 if empty.
qmi sum(const ptr_t &p) {
  if (!p)
    return qmi(0);
  return p->mul * p->sum;
}
/// Subtree total of A * weight of `p` with its pending multiplier applied;
/// 0 if empty.
qmi psum(const ptr_t &p) {
  if (!p)
    return qmi(0);
  return p->mul * p->psum;
}
/// Heap priority of `p`'s root; an empty treap reports 0.
pri_t pri(const ptr_t &p) {
  if (!p)
    return 0;
  return p->pri;
}
/// Multiplies the root's lazy multiplier by `v`; no-op on an empty treap.
void mult(ptr_t &p, const qmi v) {
  if (!p)
    return;
  p->mul *= v;
}

/// Applies `n`'s pending multiplier to its own val/sum/psum, passes it on to
/// both children's lazy multipliers, and resets it to the identity (1, 0).
void push(node_t &n) {
  const qmi lazy = n.mul;
  n.mul = qmi(1, 0);
  n.val *= lazy;
  n.sum *= lazy;
  n.psum *= lazy;
  mult(n.l, lazy);
  mult(n.r, lazy);
}

/// Recomputes `n.sum` and `n.psum` from its own value and its children's
/// aggregates (children's pending multipliers are included via sum()/psum()).
void fix(node_t &n) {
  const qmi own_weighted = mint(n.A) * n.val;
  n.sum = sum(n.l);
  n.sum += n.val;
  n.sum += sum(n.r);
  n.psum = psum(n.l);
  n.psum += own_weighted;
  n.psum += psum(n.r);
}

/// Splits treap `p` by key `A` into (keys < A, node with key == A, keys > A).
/// The middle pointer is null when no node has key `A`. Otherwise it is that
/// single node, with its children detached and its aggregates recomputed.
/// Every node on the search path is pushed before it is relinked.
/// Consumes `p`.
/// (Requires <tuple>; the file previously relied on a transitive include.)
std::tuple<ptr_t, ptr_t, ptr_t> split(ptr_t p, const int A) {
  if (!p)
    return {nullptr, nullptr, nullptr};
  push(*p);
  if (p->A < A) {
    // p and its left subtree are all < A; the remaining keys are in p->r.
    auto [l, m, r] = split(std::move(p->r), A);
    p->r = std::move(l);
    fix(*p);
    return {std::move(p), std::move(m), std::move(r)};
  }
  if (p->A > A) {
    // p and its right subtree are all > A; the remaining keys are in p->l.
    auto [l, m, r] = split(std::move(p->l), A);
    p->l = std::move(r);
    fix(*p);
    return {std::move(l), std::move(m), std::move(p)};
  }
  // Exact match: detach both subtrees and return p alone as the middle part.
  auto l = std::move(p->l);
  auto r = std::move(p->r);
  fix(*p);
  return {std::move(l), std::move(p), std::move(r)};
}

/// Merges two key-ordered weight treaps. The resulting weight at every key `a`
/// is x(a) * Y(< a) + y(a) * X(<= a), where X / Y are prefix totals of x / y.
/// (The formula is symmetric in x and y.) `xs` / `ys` are the totals of x / y
/// over all keys below the current subtree's key range. Consumes both inputs:
/// nodes of `y` whose key also occurs in `x` are folded into x's node and
/// freed; all other nodes are reused.
ptr_t conv_(ptr_t x, qmi xs, ptr_t y, qmi ys) {
  // Make x the root with the higher priority. Test emptiness explicitly:
  // pri() reports 0 for an empty treap, and a real node can also draw 0 from
  // rng, so comparing priorities alone could keep x empty while y is not and
  // then dereference null in push(*x).
  if (!x || (y && pri(x) < pri(y))) {
    std::swap(x, y);
    std::swap(xs, ys);
  }
  if (!y) {
    // No y keys in this range: every x(a) is scaled by Y(< a) == ys.
    mult(x, ys);
    return x;
  }
  push(*x);
  auto [l, m, r] = split(std::move(y), x->A);
  {
    // The left recursion sees the prefix totals from before this subtree.
    qmi xs_ = xs, ys_ = ys;
    xs += sum(x->l);
    ys += sum(l);
    x->l = conv_(std::move(x->l), xs_, std::move(l), ys_);
  }
  // Now xs == X(< x->A) and ys == Y(< x->A).
  xs += x->val;  // xs == X(<= x->A)
  x->val *= ys;  // x(a) * Y(< a)
  if (m) {
    x->val += m->val * xs;  // + y(a) * X(<= a)
    ys += m->val;           // ys == Y(<= x->A)
  }
  x->r = conv_(std::move(x->r), xs, std::move(r), ys);
  fix(*x);
  return x;
}

/// Merges treaps `x` and `y` with conv_, starting from zero prefix totals.
/// Consumes both inputs.
ptr_t conv(ptr_t x, ptr_t y) {
  const qmi no_prefix(0);
  return conv_(std::move(x), no_prefix, std::move(y), no_prefix);
}

/// Inserts node `e` (key e->A, weight e->val) into treap `p`. If `p` already
/// has a node with key e->A, the two weights are added and only one node is
/// kept. Expects `e` to be a fresh node: its children are overwritten when it
/// becomes a subtree root, and its val is used without pushing its mul.
void insert_(ptr_t &p, ptr_t e) {
  // e becomes the subtree root when it outranks p. Test emptiness explicitly:
  // pri() also reports 0 for an empty treap, so a node whose random priority
  // is 0 would otherwise fall through to push(*p) on a null p.
  if (!p || pri(p) < e->pri) {
    auto [l, m, r] = split(std::move(p), e->A);
    if (m)
      e->val += m->val;  // absorb the existing node with the same key
    e->l = std::move(l);
    e->r = std::move(r);
    fix(*e);
    p = std::move(e);
    return;
  }
  push(*p);
  if (p->A < e->A)
    insert_(p->r, std::move(e));
  else if (p->A > e->A)
    insert_(p->l, std::move(e));
  else
    p->val += e->val;  // same key: accumulate; e is freed on return
  fix(*p);
}

/// Adds weight `val` at key `A` to treap `p`, accumulating into an existing
/// node with the same key if there is one.
void insert(ptr_t &p, const int A, const qmi val) {
  ptr_t fresh = std::make_unique<node_t>(A, val);
  insert_(p, std::move(fresh));
}

#include <stack>

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int N;
  std::cin >> N;
  std::vector<int> A(N);
  for (auto &e : A)
    std::cin >> e;
  std::vector<std::vector<int>> g(N);
  rep(i, 0, N - 1) {
    int u, v;
    std::cin >> u >> v;
    u--;
    v--;
    g[u].push_back(v);
    g[v].push_back(u);
  }

  qmi ans = qmi(0);
  const mint half = mint(1) / 2;

  std::vector<int> parent(N, -1);
  std::vector<ptr_t> dp(N);
  std::vector<int> dfs;
  {
    std::stack<int> st;
    st.push(0);
    while (!st.empty()) {
      const int v = st.top();
      st.pop();
      dfs.push_back(v);
      for (const int u : g[v]) {
        if (u != parent[v]) {
          st.push(u);
          parent[u] = v;
        }
      }
    }
  }

  std::reverse(dfs.begin(), dfs.end());
  for (const int v : dfs) {
    ptr_t ret = std::make_unique<node_t>(A[v], qmi(1, 1));
    for (const int u : g[v]) {
      if (u == parent[v])
        continue;
      auto t = std::move(dp[u]);
      ans += half * psum(t);
      mult(t, qmi(half));
      insert(t, 0, qmi(half));
      ret = conv(std::move(ret), std::move(t));
    }
    dp[v] = std::move(ret);
  }

  ans += psum(dp[0]);

  rep(i, 0, N - 1) ans = 2 * ans;

  std::cout << ans.c1.v << "\n";

  return 0;
}
0