結果

問題 No.3343 Distance Sum of Large Tree
コンテスト
ユーザー jupiter_68
提出日時 2025-11-05 20:55:10
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 4,405 bytes
コンパイル時間 2,240 ms
コンパイル使用メモリ 214,244 KB
実行使用メモリ 14,244 KB
最終ジャッジ日時 2025-11-13 21:11:36
合計ジャッジ時間 5,259 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample WA * 2
other WA * 30
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ld = long double;
using pll = pair<ll, ll>;
using vl = vector<ll>;
template <class T> using vec = vector<T>;
template <class T> using vv = vec<vec<T>>;
template <class T> using vvv = vv<vec<T>>;
template <class T> using minpq = priority_queue<T, vector<T>, greater<T>>;
#define all(a) (a).begin(),(a).end()
#define rep(i, n) for (ll i = 0; i < (n); ++i)
#define reps(i, l, r) for(ll i = (l); i < (r); ++i)
#define rrep(i, l, r) for(ll i = (r)-1; i >= (l); --i)
#define sz(x) (ll) (x).size()
template <typename T>
bool chmax(T &a, const T& b) { return a < b ? a = b, true : false; }
template <typename T>
bool chmin(T &a, const T& b) { return a > b ? a = b, true : false; }

struct Edge {
    ll from, to, cost;
    Edge (ll from, ll to, ll cost = 1ll) : from(from), to(to), cost(cost) {}
};
struct Graph {
    vector<vector<Edge>> G;
    Graph() = default;
    explicit Graph(ll N) : G(N) {}
    size_t size() const {
        return G.size();
    }
    void add(ll from, ll to, ll cost = 1ll, bool direct = 0) {
        G[from].emplace_back(from, to, cost);
        if (!direct) G[to].emplace_back(to, from, cost);
    }
    vector<Edge> &operator[](const int &k) {
        return G[k];
    }
};
using Edges = vector<Edge>;

const ll mod = 998244353;

struct mint {
	ll x;
	mint(ll y = 0) : x(y >= 0 ? y % mod : (mod - (-y) % mod) % mod) {}
	mint &operator+=(const mint &p) {
		if ((x += p.x) >= mod) x -= mod;
		return *this;
	}
	mint &operator-=(const mint &p) {
		if ((x += mod - p.x) >= mod) x -= mod;
		return *this;
	}
	mint &operator*=(const mint &p) {
		x = (ll)(1ll * x * p.x % mod);
		return *this;
	}
	mint &operator/=(const mint &p) {
		*this *= p.inv();
		return *this;
	}
	mint operator-() const { return mint(-x); }
	mint operator+(const mint &p) const { return mint(*this) += p; }
	mint operator-(const mint &p) const { return mint(*this) -= p; }
	mint operator*(const mint &p) const { return mint(*this) *= p; }
	mint operator/(const mint &p) const { return mint(*this) /= p; }
	bool operator==(const mint &p) const { return x == p.x; }
	bool operator!=(const mint &p) const { return x != p.x; }
	friend ostream &operator<<(ostream &os, const mint &p) { return os << p.x; }
	friend istream &operator>>(istream &is, mint &a) {
		ll t; is >> t; a = mint(t); return (is);
	}
	mint inv() const { return pow(mod - 2); }
	mint pow(ll n) const {
		mint ret(1), mul(x);
		while (n > 0) {
			if (n & 1) ret *= mul;
			mul *= mul;
			n >>= 1;
		}
		return ret;
	}
};

void solve(){
    ll N; cin >> N;
    vl a(N); rep(i, N) cin >> a[i];
    vl b(N-1); rep(i, N-1) cin >> b[i];
    vl c(N-1); rep(i, N-1) cin >> c[i];
    vl d(N-1); rep(i, N-1) cin >> d[i];
    vec<mint> s(N);
    vv<pll> G(N);
    rep(i, N-1){
        G[i+1].push_back({b[i], d[i]-1});
        G[d[i]-1].push_back({c[i], i+1});
    }
    rep(i, N){
        sort(all(G[i]));
    }
    {
        auto dfs = [&](auto dfs, ll pos, ll pre)->void{
            s[pos] = a[pos];
            for(pll e : G[pos]){
                ll nx = e.second;
                if(nx == pre) continue;
                dfs(dfs, nx, pos);
                s[pos] += s[nx];
            }
            return;
        };
        dfs(dfs, 0, -1);
    }
    mint ans = 0;
    auto f = [&](mint l, mint r)->mint{
        if(l == r){
            return 0;
        }
        mint re = s[0];
        re *= (r - l);
        re *= (l + r - 1);
        re /= mint(2);
        mint dec1 = r, dec2 = l;
        dec1 *= r-1, dec2 *= l-1;
        dec1 *= r * 2 - 1, dec2 *= l * 2 - 1;
        dec1 -= dec2;
        dec1 /= mint(6);
        re -= dec1;
        return re;
    };
    auto dfs = [&](auto dfs, ll pos, ll pre)->void{
        mint x = 1;
        for(pll e : G[pos]){
            ll y = e.first;
            ans += f(x, y);
            x = y;
            ll nx = e.second;
            if(nx == pre){
                ans += s[pos] * (s[0] - s[pos]);
                x += (s[0] - s[pos]);
            }else{
                dfs(dfs, nx, pos);
                x += s[nx];
            }
        }
        ans += f(x, s[0]);
        return;
    };
    dfs(dfs, 0, -1);
    cout << ans << endl;
    return;
}

int main(){
    cin.tie(nullptr);
    ios_base::sync_with_stdio(false);
    cout << fixed << setprecision(20);
    int t = 1;
    // cin >> t;
    while(t--){
        solve();
    }
}
0