結果

問題 No.3343 Distance Sum of Large Tree
コンテスト
ユーザー cho435
提出日時 2025-11-14 03:27:46
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 163 ms / 2,000 ms
コード長 2,678 bytes
コンパイル時間 4,202 ms
コンパイル使用メモリ 264,388 KB
実行使用メモリ 10,984 KB
最終ジャッジ日時 2025-11-14 03:27:55
合計ジャッジ時間 8,628 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 30
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
#include <atcoder/all>
using namespace std;
using ll = long long;
#define rep(i, s, t) for (ll i = s; i < (ll)(t); i++)
#define all(x) begin(x), end(x)

template <class T> bool chmin(T& x, T y) {
	return x > y ? (x = y, true) : false;
}
template <class T> bool chmax(T& x, T y) {
	return x < y ? (x = y, true) : false;
}

struct io_setup {
	io_setup() {
		ios::sync_with_stdio(false);
		cin.tie(nullptr);
		cout << fixed << setprecision(15);
	}
} io_setup;

using mint = atcoder::modint998244353;

void solve() {
	int n;
	cin >> n;
	vector<int> a(n), b(n - 1), c(n - 1), p(n - 1);
	rep(i, 0, n) cin >> a[i];
	rep(i, 0, n - 1) cin >> b[i], b[i]--;
	rep(i, 0, n - 1) cin >> c[i], c[i]--;
	rep(i, 0, n - 1) cin >> p[i], p[i]--;
	// auto greedy = [&]() {
	// 	vector<ll> sa(n + 1);
	// 	rep(i, 0, n) sa[i + 1] = sa[i] + a[i];
	// 	int m = sa[n];
	// 	vector<vector<int>> g(m, vector<int>(m, 1e9));
	// 	rep(i, 0, n) {
	// 		rep(j, sa[i], sa[i + 1] - 1) {
	// 			g[j][j + 1] = 1;
	// 			g[j + 1][j] = 1;
	// 		}
	// 	}
	// 	rep(i, 0, n - 1) {
	// 		g[sa[i + 1] + b[i]][sa[p[i]] + c[i]] = 1;
	// 		g[sa[p[i]] + c[i]][sa[i + 1] + b[i]] = 1;
	// 	}
	// 	rep(i, 0, m) g[i][i] = 0;
	// 	rep(k, 0, m) rep(i, 0, m) rep(j, 0, m)
	// 		chmin(g[i][j], g[i][k] + g[k][j]);
	// 	// rep(i, 0, m) {
	// 	// 	rep(j, 0, m) cout << g[i][j] << ' ';
	// 	// 	cout << endl;
	// 	// }
	// 	mint ans = 0;
	// 	rep(i, 0, m) rep(j, 0, m) ans += g[i][j];
	// 	return ans;
	// };
	auto sum1 = [](ll x) {
		return mint(x + 1) * x / 2;
	};
	auto sum2 = [](ll x) {
		return mint(x + 1) * x * (2 * x + 1) / 6;
	};
	auto f = [&](ll x, ll y) {
		return sum1(x - 1) - sum1(y) * 2 + mint(2 + 2 * y - x) * y;
	};
	mint ans = 0;
	rep(i, 0, n) {
		ans += a[i] * sum1(a[i]) - sum2(a[i]);
	}
	vector<vector<array<int, 3>>> g(n);
	rep(i, 0, n - 1) {
		g[i + 1].push_back({b[i], p[i], c[i]});
		g[p[i]].push_back({c[i], (int)i + 1, b[i]});
	}
	rep(i, 0, n) sort(g[i].begin(), g[i].end());
	auto dfs = [&](auto self, int nw, int pr, int bb) -> pair<mint, mint> {
		mint sz = 0, sm = 0;
		mint tmp_sm = 0;
		int tmp_d = -1;
		for (auto [bd, nx, cd] : g[nw]) {
			if (nx == pr) continue;
			auto [csz, csm] = self(self, nx, nw, cd);
			csm += csz;
			assert(bd >= tmp_d);
			tmp_sm += (bd - tmp_d) * sz;
			ans += tmp_sm * csz;
			ans += sz * csm;
			ans += csm * a[nw];
			ans += f(a[nw], bd) * csz;
			sz += csz;
			sm += csm + abs(bd - bb) * csz;
			tmp_sm += csm;
			tmp_d = bd;
		}
		sz += a[nw];
		sm += f(a[nw], bb);
		return {sz, sm};
	};
	dfs(dfs, 0, -1, 0);
	ans *= 2;
	cout << ans.val() << '\n';
	// assert(greedy() == ans);
}

int main() {
	int t = 1;
	// cin >> t;
	while (t--) solve();
}
0