"""Sum, over all vertices u of a tree, of the number of vertex pairs split by u
whose paths enter u from two neighbours that are both strictly greater or both
strictly smaller than u.

Input (stdin): N, then N - 1 edges ``u v`` (1-indexed), then the N vertex
values. Output: the count modulo 998244353.
"""
import sys

MOD = 998244353


def _pair_product_sum(sizes):
    """Return sum(sizes[i] * sizes[j] for i < j), computed in O(len(sizes))."""
    total = sum(sizes)
    return (total * total - sum(x * x for x in sizes)) // 2


def solve(n, edges, a, mod=MOD):
    """Return the extremum-pair count for a tree, reduced modulo ``mod``.

    Args:
        n: number of vertices; vertices are ``0 .. n - 1``.
        edges: iterable of 0-indexed ``(u, v)`` pairs describing a tree.
        a: vertex values; ``a[u]`` is the value of vertex ``u``.
        mod: modulus applied to the final count.

    Returns:
        For every vertex ``u`` reachable from vertex 0, the components of
        ``tree - u`` are grouped by whether the neighbour of ``u`` leading
        into them has a value strictly greater or strictly smaller than
        ``a[u]`` (equal-valued neighbours are ignored). Within each group the
        products of component sizes over all unordered component pairs are
        summed. The grand total, modulo ``mod``, is returned.
    """
    if n <= 0:
        return 0

    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Iterative DFS from vertex 0. A recursive DFS needs one Python frame per
    # level, which overflows the recursion limit (or the C stack) on long paths.
    parent = [-1] * n
    seen = [False] * n
    seen[0] = True
    order = []
    stack = [0]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                stack.append(v)

    # Every child is appended to `order` after its parent, so sweeping in
    # reverse accumulates subtree sizes bottom-up.
    size = [1] * n
    for u in reversed(order):
        if parent[u] >= 0:
            size[parent[u]] += size[u]

    total = 0
    for u in order:
        au = a[u]
        higher = []
        lower = []
        for v in adj[u]:
            # The component reached through the parent is everything outside
            # u's subtree; a child's component is exactly its subtree.
            comp = n - size[u] if v == parent[u] else size[v]
            if a[v] > au:
                higher.append(comp)
            elif a[v] < au:
                lower.append(comp)
        total += _pair_product_sum(higher) + _pair_product_sum(lower)
    return total % mod


def main():
    """Read the tree from stdin and print the answer modulo 998244353."""
    data = sys.stdin.read().split()
    n = int(data[0])
    nums = list(map(int, data[1:]))
    # Edges are given 1-indexed; convert to 0-indexed vertex ids.
    edges = [(nums[2 * i] - 1, nums[2 * i + 1] - 1) for i in range(n - 1)]
    a = nums[2 * (n - 1): 2 * (n - 1) + n]
    print(solve(n, edges, a))


if __name__ == "__main__":
    main()