// Single-file solution bundle: template library + Math utilities + ModInt +
// the problem's solve()/main(). The `#line` markers of the bundler are
// replaced by section comments (they would be wrong after reformatting), and
// the explicit standard headers below replace a compiler-specific umbrella
// include.

// ---- library/template/template.hpp ----
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <map>
#include <numeric>
#include <set>
#include <tuple>
#include <utility>
#include <vector>
using namespace std;

// ---- library/template/macro.hpp ----
#define rep(i, a, b) for (int i = (a); i < (int)(b); i++)
#define rrep(i, a, b) for (int i = (int)(b) - 1; i >= (a); i--)
#define ALL(v) (v).begin(), (v).end()
#define UNIQUE(v) sort(ALL(v)), (v).erase(unique(ALL(v)), (v).end())
#define SZ(v) (int)(v).size()
#define MIN(v) *min_element(ALL(v))
#define MAX(v) *max_element(ALL(v))
#define LB(v, x) int(lower_bound(ALL(v), (x)) - (v).begin())
#define UB(v, x) int(upper_bound(ALL(v), (x)) - (v).begin())
#define YN(b) cout << ((b) ? "YES" : "NO") << "\n";
#define Yn(b) cout << ((b) ? "Yes" : "No") << "\n";
#define yn(b) cout << ((b) ? "yes" : "no") << "\n";

// ---- library/template/util.hpp ----
using uint = unsigned int;
using ll = long long int;
using ull = unsigned long long;
using i128 = __int128_t;
using u128 = __uint128_t;

// Sum of all elements of `a`, accumulated in type S (defaults to T).
template <class T, class S = T>
S SUM(const vector<T>& a) {
  return accumulate(ALL(a), S(0));
}
// a = min(a, b); returns true iff `a` changed.
template <class T>
inline bool chmin(T& a, T b) {
  if (a > b) {
    a = b;
    return true;
  }
  return false;
}
// a = max(a, b); returns true iff `a` changed.
template <class T>
inline bool chmax(T& a, T b) {
  if (a < b) {
    a = b;
    return true;
  }
  return false;
}
// Number of set bits (x is converted to unsigned long long).
template <class T>
int popcnt(T x) {
  return __builtin_popcountll(x);
}
// Index of the highest set bit, or -1 for x == 0.
template <class T>
int topbit(T x) {
  return (x == 0 ? -1 : 63 - __builtin_clzll(x));
}
// Index of the lowest set bit, or -1 for x == 0.
template <class T>
int lowbit(T x) {
  return (x == 0 ? -1 : __builtin_ctzll(x));
}

// ---- library/template/inout.hpp ----
// Global object whose constructor unties cin/cout, disables C-stdio sync and
// sets fixed 15-digit floating output before main() runs.
struct Fast {
  Fast() {
    cin.tie(nullptr);
    ios_base::sync_with_stdio(false);
    cout << fixed << setprecision(15);
  }
} fast;

template <class T, class U>
istream& operator>>(istream& is, pair<T, U>& p) {
  return is >> p.first >> p.second;
}
template <class T, class U>
ostream& operator<<(ostream& os, const pair<T, U>& p) {
  return os << p.first << " " << p.second;
}
template <class T>
istream& operator>>(istream& is, vector<T>& a) {
  for (auto& v : a) is >> v;
  return is;
}
// Space-separated elements, no trailing separator.
template <class T>
ostream& operator<<(ostream& os, const vector<T>& a) {
  for (auto it = a.begin(); it != a.end();) {
    os << *it;
    if (++it != a.end()) os << " ";
  }
  return os;
}
template <class T>
ostream& operator<<(ostream& os, const set<T>& st) {
  os << "{";
  for (auto it = st.begin(); it != st.end();) {
    os << *it;
    if (++it != st.end()) os << ",";
  }
  os << "}";
  return os;
}
template <class T, class U>
ostream& operator<<(ostream& os, const map<T, U>& mp) {
  os << "{";
  for (auto it = mp.begin(); it != mp.end();) {
    os << it->first << ":" << it->second;
    if (++it != mp.end()) os << ",";
  }
  os << "}";
  return os;
}
// in(a, b, ...) reads each argument from cin in order.
void in() {}
template <class T, class... U>
void in(T& t, U&... u) {
  cin >> t;
  in(u...);
}
// out(a, b, ...) prints the arguments separated by `sep`, then a newline.
void out() { cout << "\n"; }
template <char sep = ' ', class T, class... U>
void out(const T& t, const U&... u) {
  cout << t;
  if (sizeof...(u)) cout << sep;
  out(u...);
}

// ---- library/template/debug.hpp ----
#ifdef LOCAL
#define debug 1
#define show(...) _show(0, #__VA_ARGS__, __VA_ARGS__)
#else
#define debug 0
#define show(...) true
#endif
template <class T>
void _show(int i, T name) {
  cerr << '\n';
}
// Prints "name:value " for each comma-separated expression captured by show().
template <class T1, class T2, class... T3>
void _show(int i, const T1& a, const T2& b, const T3&... c) {
  for (; a[i] != ',' && a[i] != '\0'; i++) cerr << a[i];
  cerr << ":" << b << " ";
  _show(i + 1, a, c...);
}

// ---- library/math/util.hpp ----
namespace Math {
// a mod b in [0, |b|).
template <class T>
T safe_mod(T a, T b) {
  assert(b != 0);
  if (b < 0) a = -a, b = -b;
  a %= b;
  return a >= 0 ? a : a + b;
}
// floor(a / b) for any signs.
template <class T>
T floor(T a, T b) {
  assert(b != 0);
  if (b < 0) a = -a, b = -b;
  return a >= 0 ? a / b : (a + 1) / b - 1;
}
// ceil(a / b) for any signs.
template <class T>
T ceil(T a, T b) {
  assert(b != 0);
  if (b < 0) a = -a, b = -b;
  return a > 0 ? (a - 1) / b + 1 : a / b;
}
// floor(sqrt(n)) for n >= 0 (0 for n <= 0). The corrections compare via
// division so that (x + 1) * (x + 1) cannot overflow near LLONG_MAX.
long long isqrt(long long n) {
  if (n <= 0) return 0;
  long long x = (long long)sqrt((double)n);  // x >= 1 because n >= 1
  while (x + 1 <= n / (x + 1)) x++;
  while (x > n / x) x--;
  return x;
}
// return g=gcd(a,b)
// a*x+b*y=g
// - b!=0 -> 0<=x<|b|/g
// - b=0 -> ax=g
template <class T>
T ext_gcd(T a, T b, T& x, T& y) {
  T a0 = a, b0 = b;
  bool sgn_a = a < 0, sgn_b = b < 0;
  if (sgn_a) a = -a;
  if (sgn_b) b = -b;
  if (b == 0) {
    x = sgn_a ? -1 : 1;
    y = 0;
    return a;
  }
  T x00 = 1, x01 = 0, x10 = 0, x11 = 1;
  while (b != 0) {
    T q = a / b, r = a - b * q;
    x00 -= q * x01;
    x10 -= q * x11;
    swap(x00, x01);
    swap(x10, x11);
    a = b, b = r;
  }
  x = x00, y = x10;
  if (sgn_a) x = -x;
  if (sgn_b) y = -y;
  if (b0 != 0) {
    a0 /= a, b0 /= a;
    if (b0 < 0) a0 = -a0, b0 = -b0;
    T q = x >= 0 ? x / b0 : (x + 1) / b0 - 1;
    x -= b0 * q;
    y += a0 * q;
  }
  return a;
}
// Modular inverse of x mod m via the extended Euclidean algorithm.
// Only meaningful when gcd(x, m) == 1 (not checked here). The swaps are
// written out explicitly because std::swap is not constexpr before C++20.
constexpr long long inv_mod(long long x, long long m) {
  x %= m;
  if (x < 0) x += m;
  long long a = m, b = x;
  long long y0 = 0, y1 = 1;
  while (b > 0) {
    long long q = a / b;
    long long t = a - q * b;
    a = b;
    b = t;
    t = y0 - q * y1;
    y0 = y1;
    y1 = t;
  }
  if (y0 < 0) y0 += m / a;
  return y0;
}
// x^n mod m for n >= 0, m >= 1 (returns 0 when m == 1).
// Intermediate products need m^2 to fit in long long.
long long pow_mod(long long x, long long n, long long m) {
  assert(0 <= n && 1 <= m);
  if (m == 1) return 0;
  x = (x % m + m) % m;
  long long y = 1;
  while (n) {
    if (n & 1) y = y * x % m;
    x = x * x % m;
    n >>= 1;
  }
  return y;
}
// constexpr x^n mod m for 32-bit m (returns 0 when m == 1).
constexpr long long pow_mod_constexpr(long long x, long long n, int m) {
  if (m == 1) return 0;
  unsigned int um = (unsigned int)(m);
  unsigned long long r = 1;
  long long x0 = x % m;
  if (x0 < 0) x0 += m;
  unsigned long long y = (unsigned long long)x0;
  while (n) {
    if (n & 1) r = (r * y) % um;
    y = (y * y) % um;
    n >>= 1;
  }
  return r;
}
// Deterministic Miller-Rabin for 32-bit ints with bases {2, 7, 61}.
constexpr bool is_prime_constexpr(int n) {
  if (n <= 1) return false;
  if (n == 2 || n == 7 || n == 61) return true;
  if (n % 2 == 0) return false;
  long long d = n - 1;
  while (d % 2 == 0) d /= 2;
  constexpr long long bases[3] = {2, 7, 61};
  for (long long a : bases) {
    long long t = d;
    long long y = pow_mod_constexpr(a, t, n);
    while (t != n - 1 && y != 1 && y != n - 1) {
      y = y * y % n;
      t <<= 1;
    }
    if (y != n - 1 && t % 2 == 0) return false;
  }
  return true;
}
template <int n>
constexpr bool is_prime = is_prime_constexpr(n);
}  // namespace Math

// ---- library/modint/modint.hpp ----
// Integer modulo the compile-time modulus m; value kept in [0, m).
template <int m>
struct ModInt {
  static_assert(1 <= m, "modulus must be positive");
  using mint = ModInt;

  static constexpr unsigned int get_mod() { return m; }
  // Wraps v without reduction; caller guarantees 0 <= v < m.
  static mint raw(int v) {
    mint x;
    x._v = v;
    return x;
  }

  ModInt() : _v(0) {}
  ModInt(int64_t v) {
    long long x = (long long)(v % (long long)(umod()));
    if (x < 0) x += umod();
    _v = (unsigned int)(x);
  }

  unsigned int val() const { return _v; }

  mint& operator++() {
    _v++;
    if (_v == umod()) _v = 0;
    return *this;
  }
  mint& operator--() {
    if (_v == 0) _v = umod();
    _v--;
    return *this;
  }
  mint operator++(int) {
    mint result = *this;
    ++*this;
    return result;
  }
  mint operator--(int) {
    mint result = *this;
    --*this;
    return result;
  }

  mint& operator+=(const mint& rhs) {
    _v += rhs._v;
    if (_v >= umod()) _v -= umod();
    return *this;
  }
  mint& operator-=(const mint& rhs) {
    _v -= rhs._v;
    if (_v >= umod()) _v += umod();  // unsigned wrap-around means underflow
    return *this;
  }
  mint& operator*=(const mint& rhs) {
    unsigned long long z = _v;
    z *= rhs._v;
    _v = (unsigned int)(z % umod());
    return *this;
  }
  mint& operator/=(const mint& rhs) { return *this *= rhs.inv(); }
  mint operator+() const { return *this; }
  mint operator-() const { return mint() - *this; }

  // this^n for n >= 0.
  mint pow(long long n) const {
    assert(0 <= n);
    mint x = *this, r = 1;
    while (n) {
      if (n & 1) r *= x;
      x *= x;
      n >>= 1;
    }
    return r;
  }
  // Multiplicative inverse; the value must be invertible modulo m.
  mint inv() const {
    if constexpr (is_prime) {
      assert(_v);
      return pow(umod() - 2);
    } else {
      long long y = Math::inv_mod(_v, umod());
      // inv_mod does not detect gcd(_v, m) != 1; verify the result instead.
      assert((unsigned long long)y * _v % umod() == 1 % umod());
      return raw((int)y);
    }
  }

  friend mint operator+(const mint& lhs, const mint& rhs) { return mint(lhs) += rhs; }
  friend mint operator-(const mint& lhs, const mint& rhs) { return mint(lhs) -= rhs; }
  friend mint operator*(const mint& lhs, const mint& rhs) { return mint(lhs) *= rhs; }
  friend mint operator/(const mint& lhs, const mint& rhs) { return mint(lhs) /= rhs; }
  friend bool operator==(const mint& lhs, const mint& rhs) { return lhs._v == rhs._v; }
  friend bool operator!=(const mint& lhs, const mint& rhs) { return lhs._v != rhs._v; }
  friend istream& operator>>(istream& is, mint& x) {
    int64_t v;
    is >> v;
    x = mint(v);
    return is;
  }
  friend ostream& operator<<(ostream& os, const mint& x) { return os << x.val(); }

 private:
  unsigned int _v;
  static constexpr unsigned int umod() { return m; }
  static constexpr bool is_prime = Math::is_prime<m>;
};

// ---- main.cpp ----
using mint = ModInt<998244353>;

// Reads n, a[0..n), then b[1..n), c[1..n), p[1..n) (1-indexed on input),
// and prints twice an accumulated distance sum modulo 998244353.
//
// NOTE(review): the model below is reverse-engineered from the formulas and
// should be confirmed against the problem statement. Vertex x (root 0,
// parent p[i]) owns a path of a[x] cells at positions 0..a[x]-1. Cell b[i] of
// child i is joined by one unit of length to cell c[i] of its parent's path.
// a(a-1)(a+1)/6 is the sum of pairwise distances on an a-cell path, and the
// final doubling turns the unordered-pair sum into an ordered-pair sum.
void solve() {
  int n;
  in(n);
  // long long so that expressions like a[x] + 1 cannot overflow int.
  vector<long long> a(n), b(n), c(n);
  vector<int> p(n, -1);
  rep(i, 0, n) cin >> a[i];
  rep(i, 1, n) cin >> b[i], b[i]--;
  rep(i, 1, n) cin >> c[i], c[i]--;
  rep(i, 1, n) cin >> p[i], p[i]--;

  // g[x] = children of x as (attach position on x, attach position on child,
  // child), sorted so that children are visited by increasing position c.
  vector<vector<tuple<long long, long long, int>>> g(n);
  rep(i, 1, n) g[p[i]].emplace_back(c[i], b[i], i);
  rep(i, 0, n) sort(ALL(g[i]));

  // BFS order from the root. Processing it backwards finishes every child
  // before its parent without assuming p[i] < i.
  vector<int> order;
  order.reserve(n);
  if (n > 0) order.push_back(0);
  for (size_t k = 0; k < order.size(); k++) {
    for (const auto& e : g[order[k]]) order.push_back(get<2>(e));
  }

  const mint inv2 = mint(2).inv();
  const mint inv6 = mint(6).inv();
  // sum_{k=0}^{len-1} |k - pos|: total distance from position pos to every
  // position of a len-cell path.
  auto dist_sum = [&](long long len, long long pos) {
    return (mint(pos) * (pos + 1) + mint(len - 1 - pos) * (len - pos)) * inv2;
  };

  // sub_cnt[x]:  number of cells in x's subtree.
  // sub_dist[x]: sum of distances from cell b[x] of x to all those cells.
  vector<mint> sub_cnt(n), sub_dist(n);
  mint ans = 0;
  for (auto it = order.rbegin(); it != order.rend(); ++it) {
    const int x = *it;
    sub_cnt[x] = a[x];
    sub_dist[x] = dist_sum(a[x], b[x]);
    // Pairs of cells that both lie on x's own path.
    ans += mint(a[x]) * (a[x] - 1) * (a[x] + 1) * inv6;

    // Over the already-visited children z (attach position c_z <= current):
    // cnt = sum sub_cnt[z], val = sum (sub_dist[z] + sub_cnt[z] - c_z * sub_cnt[z]).
    mint cnt = 0, val = 0;
    for (const auto& e : g[x]) {
      const long long cy = get<0>(e);
      const int y = get<2>(e);
      // Pairs (cell of x's path, cell of y's subtree).
      ans += mint(a[x]) * sub_dist[y];
      ans += (mint(a[x]) + dist_sum(a[x], cy)) * sub_cnt[y];
      // Pairs (cell under earlier child z, cell under y):
      // d_z + 1 + (cy - c_z) + 1 + d_y, summed over all such pairs.
      ans += cnt * (sub_dist[y] + sub_cnt[y]);
      ans += (val + cnt * cy) * sub_cnt[y];
      cnt += sub_cnt[y];
      val += sub_dist[y] + sub_cnt[y] - cy * sub_cnt[y];
      // Fold y's subtree into x's, re-rooting distances at cell b[x].
      sub_cnt[x] += sub_cnt[y];
      sub_dist[x] += sub_dist[y] + (1 + abs(b[x] - cy)) * sub_cnt[y];
    }
  }
  ans *= 2;
  out(ans);
}

int main() {
  int t = 1;
  // in(t);
  while (t--) solve();
}