// NOTE(review): this text looks like source that lost every "<...>" span and most of its
// line breaks (presumably an HTML-stripping step - confirm against the original submission).
// Visible consequences: the #include has no header name, template parameter lists and
// container element types are absent, some operator bodies are fused together, and parts of
// main() are missing. Also, as laid out here, everything after the first "//" on the
// #include line is a comment to the compiler. The code below is left untouched; restore it
// from the original before building.
//
// modint: integer modulo `mod`, value kept in `long long v`.
//   - The converting constructor stores its argument and adds `mod` once if it is negative;
//     it performs no other reduction.
//   - += and -= apply a single conditional correction by `mod`; *= reduces with `% mod`.
//   - /= asserts b.v != 0 and multiplies by b.pow(mod-2); inv() is pow(mod-2). The original
//     comments state that a prime mod is required for both.
//   - pow(n) is square-and-multiply; for n < 0 it first replaces the base with its inv().
//   - operator>> reads a long long, applies `% mod`, then constructs a modint from it.
//   - setmod() has an empty body.
//   - NOTE(review): unary operator+ is a const member that returns a non-const reference to
//     *this - check that this compiles where it is used.
// `mint` is an alias for modint<998244353>; a global `mod` with the same value follows it.
//
// BFS(Graph, start): breadth-first search from `start` over Graph.at(v), treated as the
// list of neighbours of v. Returns one entry per vertex: 0 for `start`, ret.at(parent)+1 for
// a vertex when it is first reached, or -1 if the vertex was never reached.
//
// main(): reads N, then A (N values) and B, C, P (N-1 values each; every value of B, C and
// P is decremented right after being read). A recursive lambda `dfs(dfs, pos, back)`
// returns the pair {siz, sum + siz} and adds terms to `answer` while iterating
// Graph.at(pos) as [b, c, to] triples; finally `answer*2` is printed. The declarations of
// `answer`, `div6` and `Graph`, and the code that fills `G`, `S` and `Graph`, are not
// visible in this text.
#include using namespace std; //Input must always satisfy -mod<=a //mod<2^30. struct modint{ //mod cannot be changed. public: long long v; static void setmod(int m){} //decorative (no-op). static constexpr long long getmod(){return mod;} modint():v(0){} template modint(T a):v(a){if(v < 0) v += mod;} long long val()const{return v;} modint &operator=(const modint &b) = default; modint &operator+()const{return (*this);} modint operator-()const{return modint(0)-(*this);} modint operator+(const modint b)const{return modint(v)+=b;} modint operator-(const modint b)const{return modint(v)-=b;} modint operator*(const modint b)const{return modint(v)*=b;} modint operator/(const modint b)const{return modint(v)/=b;} modint &operator+=(const modint b){ v += b.v; if(v >= mod) v -= mod; return *this; } modint &operator-=(const modint b){ v -= b.v; if(v < 0) v += mod; return *this; } modint &operator*=(const modint b){v = v*b.v%mod; return *this;} modint &operator/=(modint b){ //requires b!=0 and a prime mod. assert(b.v != 0); (*this) *= b.pow(mod-2); return *this; } modint pow(long long n)const{ modint ret = 1,p = v; if(n < 0) p = p.inv(),n = -n; while(n){ if(n&1) ret *= p; p *= p; n >>= 1; } return ret; } modint inv()const{return pow(mod-2);} //requires a prime mod. modint &operator++(){*this += 1; return *this;} modint &operator--(){*this -= 1; return *this;} modint operator++(int){modint ret = *this; *this += 1; return ret;} modint operator--(int){modint ret = *this; *this -= 1; return ret;} friend bool operator==(const modint a,const modint b){return a.v==b.v;} friend bool operator!=(const modint a,const modint b){return a.v!=b.v;} friend bool operator<(const modint a,const modint b){return a.v=(const modint a,const modint b){return a.v>=b.v;} friend bool operator>(const modint a,const modint b){return a.v>b.v;} friend ostream &operator<<(ostream &os,const modint a){return os<>(istream &is,modint &a){ //the input is reduced modulo mod. 
long long x; is >> x; x %= mod; a = modint(x); return is; } }; using mint = modint<998244353>; const long long mod = 998244353; vector BFS(vector> &Graph,int start){ int N = Graph.size(); vector ret(N,-1); queue Q; ret.at(start) = 0,Q.push(start); while(Q.size()){ int pos = Q.front(); Q.pop(); for(auto to : Graph.at(pos)){ if(ret.at(to) != -1) continue; ret.at(to) = ret.at(pos)+1; Q.push(to); } } return ret; } int main(){ ios_base::sync_with_stdio(false); cin.tie(nullptr); int N; cin >> N; vector A(N),B(N-1),C(N-1),P(N-1); for(auto &a : A) cin >> a; for(auto &a : B) cin >> a,a--; for(auto &a : C) cin >> a,a--; for(auto &a : P) cin >> a,a--; { int n = accumulate(A.begin(),A.end(),0); vector> G(n); vector S(N); for(int i=0; i>> Graph(N); for(int i=0; i pair { mint siz = 0,sum = 0; long long a = A.at(pos); answer += a*((a-1)*a/2%mod)%mod; answer -= div6*((a-1)*a%mod)*((2*a-1)%mod); auto f = [&](long long x) -> mint {return (x*(x+1)/2%mod + (a-x)*(a-x-1)/2%mod)%mod;}; siz += a,sum += f(back); mint ksiz = 0,ksum = 0,ksum2 = 0; for(auto [b,c,to] : Graph.at(pos)){ auto [k1,k2] = dfs(dfs,to,c); answer += k2*a+k1*f(b); siz += k1,sum += k2+k1*(abs(b-back)%mod); answer += (ksiz*b-ksum)*k1; answer += ksiz*k2+ksum2*k1; ksiz += k1,ksum += k1*b,ksum2 += k2; } return {siz,sum+siz}; }; dfs(dfs,0,0); cout << answer*2 << "\n"; }