#include <stdio.h>

/*
 * Reads a tree with n vertex weights and n-1 edges (1-based endpoints) from
 * stdin and prints, modulo 998244353, the sum over every simple path with at
 * least two vertices of the product of the weights of the vertices on it.
 *
 * Input format: n, then n weights, then n-1 lines "x y".
 * Exit status 1 (with a message on stderr) on malformed input or if the
 * edges do not form a tree.
 */

enum { MAX_N = 200005 }; /* vertex capacity of the static arrays below */

static const long long MOD = 998244353;

static long long weight[MAX_N];          /* vertex weights, reduced into [0, MOD) */
static int edge_u[MAX_N], edge_v[MAX_N]; /* the n-1 input edges, 0-based */
static int adj_start[MAX_N + 1];         /* CSR: neighbours of x are adj[adj_start[x] .. adj_start[x+1]) */
static int adj[2 * MAX_N];               /* neighbour lists, grouped by vertex */
static int cursor[MAX_N];                /* scratch write positions while filling adj[] */
static int parent[MAX_N];                /* parent in the tree rooted at 0; -1 = not yet reached */
static int order[MAX_N];                 /* vertices in BFS order from vertex 0 */
static long long down[MAX_N];            /* per-vertex path sums, see sum_path_products() */

/* Parses the input into weight[], edge_u[], edge_v[]; returns 0 on success, -1 on bad input. */
static int read_input(int *n_out)
{
    long long n;
    if (scanf("%lld", &n) != 1 || n < 1 || n > MAX_N) {
        return -1;
    }
    for (long long i = 0; i < n; i++) {
        long long w;
        if (scanf("%lld", &w) != 1) {
            return -1;
        }
        /* Reduce up front so later products of two residues fit in long long. */
        w %= MOD;
        if (w < 0) {
            w += MOD;
        }
        weight[i] = w;
    }
    for (long long i = 0; i + 1 < n; i++) {
        long long x, y;
        if (scanf("%lld %lld", &x, &y) != 2 || x < 1 || x > n || y < 1 || y > n) {
            return -1;
        }
        edge_u[i] = (int)(x - 1);
        edge_v[i] = (int)(y - 1);
    }
    *n_out = (int)n;
    return 0;
}

/* Builds the undirected CSR adjacency (adj_start[], adj[]) from the n-1 edges. */
static void build_adjacency(int n)
{
    for (int x = 0; x <= n; x++) {
        adj_start[x] = 0;
    }
    /* Count degrees shifted by one so the prefix sum yields start offsets. */
    for (int i = 0; i < n - 1; i++) {
        adj_start[edge_u[i] + 1]++;
        adj_start[edge_v[i] + 1]++;
    }
    for (int x = 0; x < n; x++) {
        adj_start[x + 1] += adj_start[x];
    }
    for (int x = 0; x < n; x++) {
        cursor[x] = adj_start[x];
    }
    for (int i = 0; i < n - 1; i++) {
        adj[cursor[edge_u[i]]++] = edge_v[i];
        adj[cursor[edge_v[i]]++] = edge_u[i];
    }
}

/*
 * BFS from vertex 0, filling parent[] and order[].
 * Returns the number of vertices reached; with n-1 edges this equals n
 * exactly when the graph is a tree.
 */
static int root_tree(int n)
{
    for (int x = 0; x < n; x++) {
        parent[x] = -1;
    }
    int head = 0;
    int tail = 0;
    parent[0] = 0; /* root is its own parent; it is never merged upward */
    order[tail++] = 0;
    while (head < tail) {
        int x = order[head++];
        for (int k = adj_start[x]; k < adj_start[x + 1]; k++) {
            int y = adj[k];
            if (parent[y] < 0) {
                parent[y] = x;
                order[tail++] = y;
            }
        }
    }
    return tail;
}

/*
 * Sums, mod MOD, the weight products of all paths with >= 2 vertices.
 *
 * After x is finalised, down[x] = sum over downward paths starting at x
 * (including x alone) of the product of their weights. Before that,
 * down[p] = 1 + sum of down[c] over the children c already merged into p.
 * Merging child c into p therefore contributes
 *   down[p] * down[c] * weight[p],
 * which counts every path whose topmost vertex is p and which enters c's
 * subtree, paired with either p alone or a previously merged subtree.
 * Each path is counted exactly once, at its topmost vertex.
 */
static long long sum_path_products(int n)
{
    long long total = 0;
    for (int x = 0; x < n; x++) {
        down[x] = 1;
    }
    /* Reverse BFS order visits every vertex after all of its children; skip the root (k = 0). */
    for (int k = n - 1; k >= 1; k--) {
        int x = order[k];
        int p = parent[x];
        down[x] = down[x] * weight[x] % MOD;
        total = (total + down[p] * down[x] % MOD * weight[p]) % MOD;
        down[p] = (down[p] + down[x]) % MOD;
    }
    return total;
}

int main(void)
{
    int n;
    if (read_input(&n) != 0) {
        fputs("invalid input\n", stderr);
        return 1;
    }
    build_adjacency(n);
    if (root_tree(n) != n) {
        fputs("edges do not form a tree\n", stderr);
        return 1;
    }
    printf("%lld\n", sum_path_products(n));
    return 0;
}