結果

問題 No.3343 Distance Sum of Large Tree
コンテスト
ユーザー cho435
提出日時 2025-11-14 01:25:39
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 1,735 bytes
コンパイル時間 4,209 ms
コンパイル使用メモリ 260,188 KB
実行使用メモリ 10,988 KB
最終ジャッジ日時 2025-11-14 01:25:48
合計ジャッジ時間 7,980 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 1 WA * 29
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
#include <atcoder/all>
using namespace std;
using ll = long long;
#define rep(i, s, t) for (ll i = s; i < (ll)(t); i++)
#define all(x) begin(x), end(x)

template <class T> bool chmin(T& x, T y) {
	return x > y ? (x = y, true) : false;
}
template <class T> bool chmax(T& x, T y) {
	return x < y ? (x = y, true) : false;
}

struct io_setup {
	io_setup() {
		ios::sync_with_stdio(false);
		cin.tie(nullptr);
		cout << fixed << setprecision(15);
	}
} io_setup;

using mint = atcoder::modint998244353;

void solve() {
	int n;
	cin >> n;
	vector<int> a(n), b(n - 1), c(n - 1), p(n - 1);
	rep(i, 0, n) cin >> a[i];
	rep(i, 0, n - 1) cin >> b[i], b[i]--;
	rep(i, 0, n - 1) cin >> c[i], c[i]--;
	rep(i, 0, n - 1) cin >> p[i], p[i]--;
	p.insert(p.begin(), -1);
	auto sum1 = [](ll x) {
		return mint(x + 1) * x / 2;
	};
	auto sum2 = [](ll x) {
		return mint(x + 1) * x * (2 * x + 1) / 6;
	};
	auto f = [&](ll x, ll y) {
		return sum1(x - 1) - sum2(y) * 2 + (2 + 2 * y - x) * y;
	};
	mint ans = 0;
	rep(i, 0, n) {
		ans += a[i] * sum1(a[i]) - sum2(a[i]);
	}
	vector<vector<array<int, 3>>> g(n);
	rep(i, 1, n) {
		g[i].push_back({p[i], b[i - 1], c[i - 1]});
		g[p[i]].push_back({(int)i, c[i - 1], b[i - 1]});
	}
	auto dfs = [&](auto self, int nw, int pr, int bb) -> pair<mint, mint> {
		mint sz = 0, sm = 0;
		for (auto [nx, bd, cd] : g[nw]) {
			if (nx == pr) continue;
			auto [csz, csm] = self(self, nx, nw, cd);
			csm += csz;
			ans += csm * a[nw];
			ans += f(a[nw], bd) * csz;
			sz += csz;
			sm += csm + abs(bd - bb) * csz;
		}
		sz += a[nw];
		sm += f(a[nw], bb);
		return {sz, sm};
	};
	dfs(dfs, 0, -1, 0);
	ans *= 2;
	cout << ans.val() << '\n';
}

int main() {
	int t = 1;
	// cin >> t;
	while (t--) solve();
}
0