結果

問題 No.3343 Distance Sum of Large Tree
コンテスト
ユーザー ponjuice
提出日時 2025-10-29 18:07:03
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
(最新)
AC  
(最初)
実行時間 -
コード長 2,207 bytes
コンパイル時間 3,329 ms
コンパイル使用メモリ 304,480 KB
実行使用メモリ 90,968 KB
最終ジャッジ日時 2025-11-13 21:06:17
合計ジャッジ時間 12,832 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample WA * 2
other WA * 30
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<bits/stdc++.h>
using namespace std;
using ll = long long;
#define rep(i,a,b) for(ll i = a; i < (b); i++)

int main() {
    int n;
    cin >> n;
    ll sum = 0;
    vector<int> a(n), b(n), c(n), p(n);
    rep(i, 0, n) cin >> a[i], sum += a[i];
    rep(i, 1, n) cin >> b[i];
    rep(i, 1, n) cin >> c[i];
    rep(i, 1, n) cin >> p[i], p[i]--;

    vector<set<int>> use(n);
    rep(i,0,n) use[i].insert(1);
    rep(i,0,n) use[i].insert(a[i]);
    rep(i,1,n) {
        use[i].insert(b[i]);
        use[p[i]].insert(c[i]);
    }

    map<array<ll,2>, int> mp;
    ll gn = 0;
    rep(i,0,n){
        for(auto x: use[i]) {
            mp[{i,x}] = gn++;
        }
    }

    vector<vector<array<ll,2>>> g(gn);
    rep(i,0,n){
        int bef = -1;
        for(auto x: use[i]) {
            if(bef == -1) {
                bef = x;
                continue;
            }

            g[mp[{i,x}]].push_back({mp[{i,bef}], x-bef});
            g[mp[{i,bef}]].push_back({mp[{i,x}], x-bef});

            bef = x;
        }
    }
    rep(i,1,n){
        g[mp[{p[i], c[i]}]].push_back({mp[{i, b[i]}], 1});
        g[mp[{i, b[i]}]].push_back({mp[{p[i], c[i]}], 1});
    }

    ll ans = 0;
    const ll mod = 998244353;
    auto inv = [&](ll a) -> ll {
        ll x = mod-2;
        ll res = 1;
        while(x > 0){
            if(x&1) res = res * a % mod;
            a = a * a % mod;
            x >>= 1;
        }
        return res;
    };

    auto dfs = [&](auto&& self, int nw, int par) -> ll {
        ll now = 1;
        for(auto to: g[nw]) {
            if(to[0] == par) continue;
            now += self(self, to[0], nw) + to[1]-1;
        }

        for(auto to: g[nw]) {
            if(to[0] == par) {
                ll here = now;
                ll there = sum - here - (to[1] - 1);
                ans += (to[1]%mod) * (here%mod) % mod * (there%mod) % mod;
                ans += (to[1] % mod) * ((to[1]-1) % mod) % mod * ((here+there)%mod) % mod * inv(2) % mod;
                ans += ((to[1]-2)%mod) * ((to[1]-1)%mod) % mod * ((to[1])%mod) % mod * inv(6) % mod; 
                ans %= mod;
            }
        }

        return now;
    };
    dfs(dfs, 0, -1);

    cout << ans << endl;
}

0