結果

問題 No.2949 Product on Tree
ユーザー 00 Sakuda
提出日時 2025-01-02 00:13:49
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 396 ms / 2,000 ms
コード長 8,767 bytes
コンパイル時間 1,835 ms
コンパイル使用メモリ 140,268 KB
最終ジャッジ日時 2025-02-26 17:24:02
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 46
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <iostream>
#include <vector>
#include <math.h>
#include <algorithm>
#include <set>
#include <map>
#include <unordered_map>
#include <queue>
#include <deque>
#include <stack>
#include <string>
#include <bitset>
#include <iomanip>
using namespace std;
// Short aliases for frequently used container/number types.
using ll = long long;
using VVI = vector<vector<int>>;
using VVL = vector<vector<ll>>;
using VI = vector<int>;
using VL = vector<ll>;
using VS = vector<string>;
using VC = vector<char>;
using VP = vector<pair<int, int>>;
using Graph0 = vector<vector<int>>;  // unweighted adjacency list
// Loop helpers: rep = i in [0, n); drep = i from a down to b (inclusive);
// urep = i from a up to b (inclusive). The l-prefixed variants use a long long index.
#define rep(i, n) for (int i = 0; i < (int)(n); i++)
#define drep(i, a, b) for (int i = (int)(a);i >= (int)(b);i--)
#define urep(i, a, b) for (int i = (int)(a);i <= (int)(b);i++)
#define lrep(i, n) for (ll i = 0; i < (ll)(n); i++)
#define ldrep(i, a, b) for (ll i = (ll)(a);i >= (ll)(b);i--)
#define lurep(i, a, b) for (ll i = (ll)(a);i <= (ll)(b);i++)
#define arep(i, v) for (auto i : v)  // iterates over copies of the elements
#define all(a) (a).begin(), (a).end()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl
// Print the verdict and terminate. The braces keep both statements under an
// enclosing if/loop; unbraced, `if (c) eyes` would run exit(0) unconditionally.
#define eyes { cout << "Yes" << endl; exit(0); }
#define eno { cout << "No" << endl; exit(0); }
// Raises `a` to `b` when `b` is strictly larger (compared with operator<).
// Returns true exactly when `a` was modified.
template <typename T>
bool chmax(T &a, const T& b) {
  if (!(a < b)) return false;
  a = b;
  return true;
}
// Lowers `a` to `b` when `a` is strictly larger (compared with operator>).
// Returns true exactly when `a` was modified.
template <typename T>
bool chmin(T &a, const T& b) {
  if (!(a > b)) return false;
  a = b;
  return true;
}
// Writes A and a newline to stdout (flushing via endl), then ends the program
// with exit status 0. Never returns.
template<typename T>
void excout(T A) {
  std::cout << A << std::endl;
  std::exit(0);
}
// "Infinity" sentinel, 2^60. Beware: 2 * INF still fits in long long,
// but larger multiples or products of INF overflow.
constexpr long long INF = (1LL << 60);


// Directed, weighted edge stored in an adjacency list.
struct Edge
{
  int to;    // destination vertex
  int cost;  // edge weight (Dijkstra assumes it is non-negative)
};
using Graph = std::vector<std::vector<Edge>>;
using Pair = std::pair<long long, int>;  // (tentative distance, vertex)

// Single-source shortest paths from startIndex.
// `distances` is indexed directly and never resized, so it must already hold
// graph.size() entries, each larger than any real path length (e.g. INF).
// On return, distances[v] is the shortest distance to v; unreachable vertices
// keep their initial value. Edge costs are assumed non-negative (not checked).
void Dijkstra(const Graph& graph, std::vector<long long>& distances, int startIndex)
{
  std::priority_queue<Pair, std::vector<Pair>, std::greater<Pair>> frontier;
  distances[startIndex] = 0;
  frontier.emplace(0LL, startIndex);

  while (!frontier.empty())
  {
    const auto [dist, u] = frontier.top();
    frontier.pop();
    // A cheaper route to u was already settled; this heap entry is stale.
    if (dist > distances[u]) continue;

    for (const Edge& e : graph[u])
    {
      const long long candidate = distances[u] + e.cost;
      if (candidate >= distances[e.to]) continue;
      distances[e.to] = candidate;
      frontier.emplace(candidate, e.to);
    }
  }
}
// Returns the non-negative residue of `a` modulo `mods` (a value in [0, mods)
// for mods > 0), also when `a` is negative.
// `a` is reduced first so the intermediate sum stays within (0, 2 * mods);
// the former `a + mods` could overflow (signed UB) for `a` near the type's max.
template<typename T>
T MODS(T a, T mods) {
  return ((a % mods) + mods) % mods;
}
// Builds Pascal's triangle without any modulus: returns an (n+1) x (n+1) table t
// with t[i][j] = C(i, j) for 0 <= j <= i <= n; cells with j > i stay 0.
// Entries are exact long long values, so they overflow once C(i, j) exceeds
// 2^63 - 1 (first at i = 67). The parameter r is not used.
VVL comb(int n, int r) {
  (void)r;
  VVL table(n + 1, VL(n + 1, 0));
  for (int row = 0; row <= n; ++row) {
    table[row][0] = 1;
    table[row][row] = 1;
    // Each interior cell is the sum of the two cells above it.
    for (int col = 1; col < row; ++col) {
      table[row][col] = table[row - 1][col - 1] + table[row - 1][col];
    }
  }
  return table;
}
// Factorizes N by trial division and returns (prime, exponent) pairs in
// increasing order of the prime. N == 1 yields an empty list. For N <= 0 the
// loop never runs and the result is {(N, 1)}.
vector<pair<long long, long long> > prime_factorize(long long N) {
    vector<pair<long long, long long> > res;
    // `p <= N / p` is the overflow-free form of `p * p <= N`: for N close to
    // LLONG_MAX the product p * p would overflow before the loop ends.
    for (long long p = 2; p <= N / p; ++p) {
        if (N % p != 0) {
            continue;
        }
        long long e = 0;
        while (N % p == 0) {
            ++e;
            N /= p;
        }
        res.emplace_back(p, e);
    }
    // Whatever remains above 1 has no factor <= its square root, so it is prime.
    if (N != 1) {
        res.emplace_back(N, 1);
    }
    return res;
}

// Disjoint-set union with union by size and path compression.
// par[x] == -1 marks x as a representative; siz[r] is meaningful only when r is one.
struct UnionFind {
    vector<int> par, siz;
    UnionFind(int n) : par(n, -1), siz(n, 1) {}

    // Representative of x's set; compresses the path on the way back.
    int root(int x) {
        return par[x] == -1 ? x : (par[x] = root(par[x]));
    }

    bool issame(int x, int y) { return root(x) == root(y); }

    // Merges the sets of x and y (smaller under larger).
    // Returns false when they were already in the same set.
    bool unite(int x, int y) {
        int rx = root(x);
        int ry = root(y);
        if (rx == ry) return false;
        if (siz[rx] < siz[ry]) std::swap(rx, ry);
        par[ry] = rx;
        siz[rx] += siz[ry];
        return true;
    }

    // Number of elements in x's set.
    int size(int x) { return siz[root(x)]; }
};
// Weighted union-find: besides the partition, it keeps for every element x a
// weight W(x) relative to its set's root, so that W(y) - W(x) can be queried for
// x, y in the same set. Uses union by rank and path compression.
// Abel must support +=, -=, unary - and construction from 0.
template<class Abel> struct WUnionFind {
    vector<int> par;           // parent; par[x] == x for roots
    vector<int> rank;          // union-by-rank bound, meaningful for roots
    vector<Abel> diff_weight;  // weight of x relative to par[x]

    WUnionFind(int n = 1, Abel SUM_UNITY = 0) {
        init(n, SUM_UNITY);
    }

    // Resets to n singletons, each with weight SUM_UNITY.
    void init(int n = 1, Abel SUM_UNITY = 0) {
        par.resize(n);
        rank.resize(n);
        diff_weight.resize(n);
        for (int i = 0; i < n; ++i) {
            par[i] = i;
            rank[i] = 0;
            diff_weight[i] = SUM_UNITY;
        }
    }

    // Root of x. Afterwards par[x] is the root and diff_weight[x] is relative to it.
    int root(int x) {
        if (par[x] == x) return x;
        const int r = root(par[x]);
        // par[x] still names the old parent, whose weight is now relative to r.
        diff_weight[x] += diff_weight[par[x]];
        par[x] = r;
        return r;
    }

    // W(x) relative to x's root.
    Abel weight(int x) {
        root(x);
        return diff_weight[x];
    }

    bool issame(int x, int y) { return root(x) == root(y); }

    // Records W(y) = W(x) + w. Returns false (and changes nothing) when x and y
    // are already in the same set.
    bool merge(int x, int y, Abel w) {
        w += weight(x);
        w -= weight(y);
        int rx = root(x);
        int ry = root(y);
        if (rx == ry) return false;
        if (rank[rx] < rank[ry]) {
            std::swap(rx, ry);
            w = -w;
        }
        if (rank[rx] == rank[ry]) ++rank[rx];
        par[ry] = rx;
        diff_weight[ry] = w;
        return true;
    }

    // W(y) - W(x); meaningful only when issame(x, y).
    Abel diff(int x, int y) { return weight(y) - weight(x); }
};
// Kahn's algorithm on the directed graph G (G[v] lists the heads of v's out-edges).
// Returns the vertices in a topological order, processing ready vertices in FIFO order.
// If G has a cycle, vertices on it or reachable from it never reach in-degree 0 and
// are left out, so the result is then shorter than G.size().
VI topo_sort(Graph0& G) {
  const int n = G.size();
  VI indeg(n, 0);
  for (const auto& adj : G) {
    for (const int to : adj) ++indeg[to];
  }
  queue<int> ready;
  for (int v = 0; v < n; ++v) {
    if (indeg[v] == 0) ready.push(v);
  }
  VI order;
  while (!ready.empty()) {
    const int v = ready.front();
    ready.pop();
    order.push_back(v);
    for (const int to : G[v]) {
      if (--indeg[to] == 0) ready.push(to);
    }
  }
  return order;
}
// Inserts the undirected edge {a, b} into adjacency list G by appending each
// endpoint to the other's list. Duplicate edges are not detected.
void ADD(int a, int b, Graph0& G) {
  G[a].emplace_back(b);
  G[b].emplace_back(a);
}
// In-bounds 4-neighbours of cell (i, j) in an H x W grid, in the order
// (i-1, j), (i+1, j), (i, j-1), (i, j+1).
VP near(int i, int j, int H, int W) {
  static constexpr int di[4] = {-1, 1, 0, 0};
  static constexpr int dj[4] = {0, 0, -1, 1};
  VP ans;
  for (int d = 0; d < 4; ++d) {
    const int ni = i + di[d];
    const int nj = j + dj[d];
    const bool inside = 0 <= ni && ni < H && 0 <= nj && nj < W;
    if (inside) ans.emplace_back(ni, nj);
  }
  return ans;
}
// Flattens grid cell (i, j) of a W-column grid to the row-major index i * W + j.
// H is accepted for symmetry with near() but is not used.
int cast(int i, int j, int H, int W) {
  static_cast<void>(H);
  const int rowStart = W * i;
  return rowStart + j;
}
// Computes x^n mod `mod` by iterative binary exponentiation.
// Preconditions: n >= 0, mod >= 1, and (mod - 1)^2 must fit in long long.
// Always returns a value in [0, mod): a negative base is normalized first
// (the old recursive version could return a negative residue), and mod == 1
// yields 0 even for n == 0.
long long pows(long long x, long long n, long long mod) {
    x %= mod;
    if (x < 0) x += mod;
    long long result = 1 % mod;
    while (n > 0) {
        if (n & 1) result = result * x % mod;
        x = x * x % mod;
        n >>= 1;
    }
    return result;
}
// Binomial coefficients modulo `mod` via precomputed factorial tables.
// Usage: construct with table size `max`, call solve() once, then comb(n, k).
// Preconditions: mod is prime and larger than max - 1 (the inverse recurrence in
// solve() relies on it), and comb() is only called with n < MAX.
struct COMB_MOD {
  long long mod;
  int MAX;
  // fac[i] = i!, inv[i] = i^{-1}, finv[i] = (i!)^{-1}, all modulo mod.
  std::vector<long long> fac, finv, inv;
  COMB_MOD(int max, long long m) {
    fac.assign(max, 0);
    finv.assign(max, 0);
    inv.assign(max, 0);
    mod = m;
    MAX = max;
  }
  void solve() {
    // Base cases are guarded so tables of size 0 or 1 are not indexed out of bounds.
    for (int i = 0; i < MAX && i < 2; i++) {
      fac[i] = 1;
      finv[i] = 1;
    }
    if (MAX > 1) inv[1] = 1;
    for (int i = 2; i < MAX; i++) {
      fac[i] = fac[i - 1] * i % mod;
      // inv[i] = -(mod / i) * inv[mod % i]; mod % i < i, so it is already filled.
      inv[i] = mod - inv[mod % i] * (mod / i) % mod;
      finv[i] = finv[i - 1] * inv[i] % mod;
    }
  }
  // C(n, k) mod `mod`; 0 when k > n or either argument is negative.
  long long comb(int n, int k) {
    if (n < k) return 0;
    if (n < 0 || k < 0) return 0;
    return fac[n] * (finv[k] * finv[n - k] % mod) % mod;
  }
};

struct LCA {
    vector<vector<int>> parent;  
    vector<int> dist;           
    LCA(const Graph0 &G, int root = 0) { init(G, root); }
    void init(const Graph0 &G, int root = 0) {
        int V = G.size();
        int K = 1;
        while ((1 << K) < V) K++;
        parent.assign(K, vector<int>(V, -1));
        dist.assign(V, -1);
        dfs(G, root, -1, 0);
        for (int k = 0; k + 1 < K; k++) {
            for (int v = 0; v < V; v++) {
                if (parent[k][v] < 0) {
                    parent[k + 1][v] = -1;
                } else {
                    parent[k + 1][v] = parent[k][parent[k][v]];
                }
            }
        }
    }
    void dfs(const Graph0 &G, int v, int p, int d) {
        parent[0][v] = p;
        dist[v] = d;
        for (auto e : G[v]) {
            if (e != p) dfs(G, e, v, d + 1);
        }
    }
    int query(int u, int v) {
        if (dist[u] < dist[v]) swap(u, v);  
        int K = parent.size();
        for (int k = 0; k < K; k++) {
            if ((dist[u] - dist[v]) >> k & 1) {
                u = parent[k][u];
            }
        }
        if (u == v) return u;
        for (int k = K - 1; k >= 0; k--) {
            if (parent[k][u] != parent[k][v]) {
                u = parent[k][u];
                v = parent[k][v];
            }
        }
        return parent[0][u];
    }
    int get_dist(int u, int v) { return dist[u] + dist[v] - 2 * dist[query(u, v)]; }
};
#include <atcoder/modint>
using namespace atcoder;
using mint = modint998244353;
// Shared state of the solution below. Every per-vertex array is pre-sized to MN,
// so MN must be at least the vertex count N (presumably the problem's upper
// bound on N — confirm against the statement).
int MN = 1e6;
Graph0 G(MN);            // tree adjacency list, 0-indexed vertices (filled in main)
VL A(MN);                // A[v]: value attached to vertex v (read in main)
// dp[v] (set by dfs): sum, over downward paths starting at v in v's subtree, of
// the product of A along the path.
vector<mint> dp(MN, 0);
// DP[v] (set by dfs1 before visiting v): sum over paths that start at v's parent
// and do not enter v's subtree; stays 0 for the root.
vector<mint> DP(MN, 0);
void dfs(int v, int par) {
    dp[v] = A[v];
    arep(nv, G[v]) {
        if (nv == par) continue;
        dfs(nv, v);
        dp[v] += A[v] * dp[nv];
    }
}
mint ret = 0;
void dfs1(int v, int par) {
    mint sm = A[v] * (1 + DP[v]);
    ret += dp[v] + (A[v] * DP[v]);
    arep(nv, G[v]) {
        if (nv == par) continue;
        sm += A[v] * dp[nv];
    }
    arep(nv, G[v]) {
        if (nv == par) continue;
        DP[nv] = sm -  (A[v] * dp[nv]);
        dfs1(nv, v);
    }
}
int main(void) {
    int N;cin >> N;
    rep(i, N) cin >> A[i];  
    rep(i, N-1) {
        int u, v;cin >> u >> v;
        u--;v--;
        ADD(u, v, G);
    }
    dfs(0, 0);
    dfs1(0, 0);
    rep(v, N) ret -= A[v];
    ret *= mint(2).inv();
    cout << ret.val() << endl;
}
0