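結果 (Result)

| 項目 (Field) | 値 (Value) |
|---|---|
| 問題 (Problem) | No.3343 Distance Sum of Large Tree |
| コンテスト (Contest) | |
| ユーザー (User) | |
| 提出日時 (Submitted at) | 2025-11-05 20:55:10 |
| 言語 (Language) | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 (Verdict) | WA |
| 実行時間 (Run time) | - |
| コード長 (Code size) | 4,405 bytes |
| コンパイル時間 (Compile time) | 2,240 ms |
| コンパイル使用メモリ (Compile memory) | 214,244 KB |
| 実行使用メモリ (Run memory) | 14,244 KB |
| 最終ジャッジ日時 (Last judged at) | 2025-11-13 21:11:36 |
| 合計ジャッジ時間 (Total judge time) | 5,259 ms |
| ジャッジサーバーID (参考情報) (Judge server ID, for reference) | judge4 / judge2 |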
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | WA * 2 |
| other | WA * 30 |
ソースコード
#include <bits/stdc++.h>
using namespace std;
// Short aliases for frequently used types.
using ll = long long;
using ld = long double;
using pll = pair<ll, ll>;
using vl = vector<ll>;
template <class T> using vec = vector<T>;
template <class T> using vv = vec<vec<T>>;   // 2-D vector
template <class T> using vvv = vv<vec<T>>;   // 3-D vector
// Min-heap: top() is the smallest element under greater<T>.
template <class T> using minpq = priority_queue<T, vector<T>, greater<T>>;
// Iterator pair for whole-container algorithm calls, e.g. sort(all(v)).
#define all(a) (a).begin(),(a).end()
// i = 0, 1, ..., n-1 (index type ll).
#define rep(i, n) for (ll i = 0; i < (n); ++i)
// i = l, l+1, ..., r-1 (half-open [l, r)).
#define reps(i, l, r) for(ll i = (l); i < (r); ++i)
// i = r-1, r-2, ..., l (half-open [l, r) walked backwards).
#define rrep(i, l, r) for(ll i = (r)-1; i >= (l); --i)
// Container size as a signed ll, avoiding signed/unsigned comparisons.
#define sz(x) (ll) (x).size()
/// Raises `a` to `b` when `b` is strictly greater than `a`.
/// @return true iff `a` was changed.
template <typename T>
bool chmax(T &a, const T& b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}
/// Lowers `a` to `b` when `b` is strictly less than `a`.
/// @return true iff `a` was changed.
template <typename T>
bool chmin(T &a, const T& b) {
    if (!(a > b)) return false;
    a = b;
    return true;
}
struct Edge {
ll from, to, cost;
Edge (ll from, ll to, ll cost = 1ll) : from(from), to(to), cost(cost) {}
};
/// Adjacency-list graph: G[v] holds the Edge objects stored at vertex v.
struct Graph {
    vector<vector<Edge>> G;

    Graph() = default;
    explicit Graph(ll N) : G(N) {}

    /// Number of vertices.
    size_t size() const { return G.size(); }

    /// Stores from -> to; unless `direct` is set, also stores to -> from.
    void add(ll from, ll to, ll cost = 1ll, bool direct = 0) {
        G[from].push_back(Edge(from, to, cost));
        if (direct) return;
        G[to].push_back(Edge(to, from, cost));
    }

    /// Mutable access to the edge list of vertex k.
    vector<Edge> &operator[](const int &k) { return G[k]; }
};
// List of edges.
typedef vector<Edge> Edges;
// Prime modulus used by mint (its inverse relies on Fermat's little theorem).
constexpr ll mod = 998'244'353;
struct mint {
ll x;
mint(ll y = 0) : x(y >= 0 ? y % mod : (mod - (-y) % mod) % mod) {}
mint &operator+=(const mint &p) {
if ((x += p.x) >= mod) x -= mod;
return *this;
}
mint &operator-=(const mint &p) {
if ((x += mod - p.x) >= mod) x -= mod;
return *this;
}
mint &operator*=(const mint &p) {
x = (ll)(1ll * x * p.x % mod);
return *this;
}
mint &operator/=(const mint &p) {
*this *= p.inv();
return *this;
}
mint operator-() const { return mint(-x); }
mint operator+(const mint &p) const { return mint(*this) += p; }
mint operator-(const mint &p) const { return mint(*this) -= p; }
mint operator*(const mint &p) const { return mint(*this) *= p; }
mint operator/(const mint &p) const { return mint(*this) /= p; }
bool operator==(const mint &p) const { return x == p.x; }
bool operator!=(const mint &p) const { return x != p.x; }
friend ostream &operator<<(ostream &os, const mint &p) { return os << p.x; }
friend istream &operator>>(istream &is, mint &a) {
ll t; is >> t; a = mint(t); return (is);
}
mint inv() const { return pow(mod - 2); }
mint pow(ll n) const {
mint ret(1), mul(x);
while (n > 0) {
if (n & 1) ret *= mul;
mul *= mul;
n >>= 1;
}
return ret;
}
};
void solve(){
ll N; cin >> N;
vl a(N); rep(i, N) cin >> a[i];
vl b(N-1); rep(i, N-1) cin >> b[i];
vl c(N-1); rep(i, N-1) cin >> c[i];
vl d(N-1); rep(i, N-1) cin >> d[i];
vec<mint> s(N);
vv<pll> G(N);
rep(i, N-1){
G[i+1].push_back({b[i], d[i]-1});
G[d[i]-1].push_back({c[i], i+1});
}
rep(i, N){
sort(all(G[i]));
}
{
auto dfs = [&](auto dfs, ll pos, ll pre)->void{
s[pos] = a[pos];
for(pll e : G[pos]){
ll nx = e.second;
if(nx == pre) continue;
dfs(dfs, nx, pos);
s[pos] += s[nx];
}
return;
};
dfs(dfs, 0, -1);
}
mint ans = 0;
auto f = [&](mint l, mint r)->mint{
if(l == r){
return 0;
}
mint re = s[0];
re *= (r - l);
re *= (l + r - 1);
re /= mint(2);
mint dec1 = r, dec2 = l;
dec1 *= r-1, dec2 *= l-1;
dec1 *= r * 2 - 1, dec2 *= l * 2 - 1;
dec1 -= dec2;
dec1 /= mint(6);
re -= dec1;
return re;
};
auto dfs = [&](auto dfs, ll pos, ll pre)->void{
mint x = 1;
for(pll e : G[pos]){
ll y = e.first;
ans += f(x, y);
x = y;
ll nx = e.second;
if(nx == pre){
ans += s[pos] * (s[0] - s[pos]);
x += (s[0] - s[pos]);
}else{
dfs(dfs, nx, pos);
x += s[nx];
}
}
ans += f(x, s[0]);
return;
};
dfs(dfs, 0, -1);
cout << ans << endl;
return;
}
/// Entry point: sets up fast stream I/O and runs the test case(s).
int main(){
    ios_base::sync_with_stdio(false);   // only iostreams are used, no C stdio
    cin.tie(nullptr);                   // don't flush cout before every read
    cout << fixed << setprecision(20);  // fixed-point formatting for any floating output
    int t = 1;
    // For multi-test input, read the count here: cin >> t;
    for (int tc = 0; tc < t; ++tc) {
        solve();
    }
}