結果

問題 No.3343 Distance Sum of Large Tree
コンテスト
ユーザー こめだわら
提出日時 2025-11-14 00:26:36
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 191 ms / 2,000 ms
コード長 3,167 bytes
コンパイル時間 6,711 ms
コンパイル使用メモリ 335,884 KB
実行使用メモリ 15,240 KB
最終ジャッジ日時 2025-11-14 00:26:48
合計ジャッジ時間 10,997 ms
ジャッジサーバーID
(参考情報)
judge4 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 30
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ld = long double;
using ull = unsigned long long;

#define rep(i,n) for(ll i=0;i<n;++i)
#define all(a) (a).begin(),(a).end()
ll intpow(ll a, ll b){ ll ans = 1; while(b){ if(b & 1) ans *= a; a *= a; b /= 2; } return ans; }
ll modpow(ll a, ll b, ll p){ ll ans = 1; while(b){ if(b & 1) (ans *= a) %= p; (a *= a) %= p; b /= 2; } return ans; }
template<class T> T div_floor(T a, T b) { return a / b - ((a ^ b) < 0 && a % b); }
template<class T> T div_ceil(T a, T b) { return a / b + ((a ^ b) > 0 && a % b); }
template <typename T, typename U> inline bool chmin(T &x, U y) { return (y < x) ? (x = y, true) : false; }
template <typename T, typename U> inline bool chmax(T &x, U y) { return (x < y) ? (x = y, true) : false; }

template<typename T>
ostream &operator<<(ostream &os, const vector<T> &a){
    if (a.empty()) return os;
    os << a.front();
    for (auto e : a | views::drop(1)){
        os << ' ' << e;
    }
    return os;
}

void dump(auto ...vs){
    ((cout << vs << ' '), ...) << endl;
}

#include <atcoder/all>
using namespace atcoder;
using mint = modint998244353;

void solve() {
    ll N;
    cin>>N;
    vector<ll> A(N);
    ll sa=0;
    rep(i,N){
        cin>>A[i];
        sa+=A[i];
    }
    vector<vector<pair<ll,ll>>> child(N);
    vector<pair<ll,ll>> parent(N);
    vector<ll> B(N);
    rep(i,N-1){
        cin>>B[i+1];
        B[i+1]--;
    }
    vector<ll> C(N);
    rep(i,N-1){
        cin>>C[i+1];
        C[i+1]--;
    }
    vector<ll> P(N);
    rep(i,N-1){
        cin>>P[i+1];
        P[i+1]--;
    }
    for (ll i=1;i<N;i++){
        parent[i]={B[i],P[i]};
        child[P[i]].emplace_back(C[i],i);
    }
    vector<ll> S(N);
    auto dfs=[&](auto self,ll cp)->ll {
        ll v=A[cp];
        for (auto [_,np]:child[cp]){
            v+=self(self,np);
        }
        S[cp]=v;
        return v;
    };
    dfs(dfs,0);
    mint ans=0;
    for (ll cp=1;cp<N;cp++){
        mint v=S[cp];
        ans+=v*(sa-v);
    }
    rep(cp,N){
        vector<pair<ll,mint>> C;
        for (auto [a,np]:child[cp]){
            C.emplace_back(a,S[np]);
        }
        if (cp!=0){
            C.emplace_back(parent[cp].first,sa-S[cp]);
        }
        sort(all(C),[](pair<ll,mint> a,pair<ll,mint> b){
            return (a.first<b.first);
        });
        C.emplace_back(A[cp]-1,0);
        ll cs=C.size();
        mint rs=0;
        for (auto [_,v]:C){
            rs+=v;
        }
        // dump(rs.val());
        mint ls=0;
        ll nowa=1;
        for (auto [b,v]:C){
            mint right=rs+A[cp];
            mint left=ls;
            mint t=0;
            t+=(mint)(b-nowa+1)*left*right;
            t+=(mint)(right-left)*(b-nowa+1)*(nowa+b)/2;
            t-=(mint)(b+1)*b*(2*b+1)/6;
            t+=(mint)(nowa-1)*nowa*(2*nowa-1)/6;
            // dump(cp,b,left.val(),right.val(),t.val());
            ans+=t;
            ls+=v;
            rs-=v;
            nowa=b+1;
        }
    }
    ans*=2;
    cout<<ans.val()<<'\n';
    return;
}


int main() {
    cin.tie(0)->sync_with_stdio(0);
    ll T=1;
    while (T--){
        solve();
    }
    return 0;
}
0