#include #include using namespace std; using namespace atcoder; //using mint = modint1000000007; //const int mod = 1000000007; using mint = modint998244353; const int mod = 998244353; //const int INF = 1e9; //const long long LINF = 1e18; #define rep(i, n) for (int i = 0; i < (n); ++i) #define rep2(i,l,r)for(int i=(l);i<(r);++i) #define rrep(i, n) for (int i = (n) - 1; i >= 0; --i) #define rrep2(i,l,r)for(int i=(r) - 1;i>=(l);--i) #define all(x) (x).begin(),(x).end() #define allR(x) (x).rbegin(),(x).rend() #define P pair template inline bool chmax(A & a, const B & b) { if (a < b) { a = b; return true; } return false; } template inline bool chmin(A & a, const B & b) { if (a > b) { a = b; return true; } return false; } int main() { std::ios::sync_with_stdio(false); std::cin.tie(nullptr); int n; cin >> n; vectora(n), b(n - 1), c(n - 1), p(n - 1); rep(i, n)cin >> a[i]; rep(i, n - 1)cin >> b[i], b[i]--; rep(i, n - 1)cin >> c[i], c[i]--; rep(i, n - 1)cin >> p[i], p[i]--; vector>mps(n); vector to(n, vector()); rep(i, n - 1) { to[p[i]].push_back(i + 1); } vectorchild(n); mint sum = 0; rep(i, n)sum += a[i]; { auto dfs = [&](auto &&self, int v, int p = -1)->void { child[v] = a[v]; for (auto nv : to[v]) { if (p == nv)continue; self(self, nv, v); child[v] += child[nv]; } }; dfs(dfs, 0); } rep(i, n) { mps[i][0] = 0; mps[i][a[i] - 1] = 0; } rep(i, n - 1) { mps[p[i]][c[i]] += child[i + 1]; mps[i + 1][b[i]] += sum - child[i + 1]; } mint ans = 0; auto calc = [&](int v, map&mp)->void { int pre = -1; mint tmp = 0; for (auto[k, v] : mp) { if (pre == -1) { pre = k; tmp += v; continue;; } mint sz = k - pre; mint l = tmp + pre + 1; mint r = sum - l; ans += l * r* sz; //cout << l.val() << " " << r.val() << " " << sz.val() << endl; mint x = sz * (sz - 1) / 2; ans += x * (r - l); ans -= (sz - 1)*(sz) * (2 * sz - 1) / 6; pre = k; tmp += v; } }; auto dfs = [&](auto &&self, int v, int p = -1)->void { ans += (sum - child[v]) *child[v]; calc(v, mps[v]); for (auto nv : to[v]) { if (p == nv)continue; self(self, nv, v); } }; dfs(dfs, 0); ans *= 2; cout << ans.val() << endl; return 0; }