/*
 * NOTE(review): this file appears to be a lossy extraction of a C++
 * competitive-programming solution. Text between angle brackets seems to
 * have been stripped, and all newlines collapsed into one physical line.
 * Visible evidence:
 *   - `#include` has no header name;
 *   - `template inline bool chmax(T &a, T b)` has no template parameter list;
 *   - type aliases such as `using vvl=vector>;` lack their template arguments;
 *   - the `rep` macro reads `for(ll i=0;i=0;i--)`, i.e. two macros fused;
 *   - inside dfs, `if(a[par]> n;` jumps directly into what looks like main's
 *     input reading, so the end of dfs and the start of main are missing;
 *   - `rev`, `ll`, `vl` and similar names are used, but their definitions are
 *     not visible.
 * Because everything sits on one physical line, the first `//` comment
 * (`//{value,index}`) also turns the whole remainder of that line into a
 * comment. As written this cannot compile. Restore it from the original
 * source rather than patching it here.
 *
 * Visible logic (for whoever restores it):
 *   - Globals: values a and adjacency list g, each sized 5e4+10; best answer
 *     `ans`.
 *   - Input (partly lost): reads a[i] for i in [0,n), then n-1 edges x y.
 *     The edges are decremented (1-indexed input -> 0-indexed) and added to
 *     g in both directions.
 *   - dfs(v, par) recurses into every neighbour except par. It collects
 *     {child result, child index} pairs and sorts them, then applies `rev`
 *     (presumably reverse, giving descending order; not visible here).
 *   - Scoring at v: for children with a[child] > a[v] ("up") and, separately,
 *     a[child] < a[v] ("down"), it keeps at most one child per distinct a
 *     value and skips negative child results. It sums their results into a
 *     score, counts them into a degree, and updates ans with
 *     score + deg*(deg-1) for each side.
 *   - Return value: returns 0 at the root (v==0) and -INF when a[par]==a[v].
 *     Otherwise it recomputes both sides with the degree starting at 1,
 *     excluding children whose a value equals a[par]. The final return
 *     expression after `if(a[par]` is lost.
 *   - main calls dfs(0,-1) and prints ans/2.
 * NOTE(review): the exact problem statement is not recoverable from here;
 * the meaning of score/degree is unconfirmed.
 */
#include using namespace std; #pragma GCC optimize("O3") #define rep(i,n) for(ll i=0;i=0;i--) #define perl(i,r,l) for(ll i=r-1;i>=l;i--) #define fi first #define se second #define ins insert #define pqueue(x) priority_queue,greater> #define all(x) (x).begin(),(x).end() #define CST(x) cout<; using vvl=vector>; using pl=pair; using vpl=vector; using vvpl=vector; const ll MOD=1000000007; const ll MOD9=998244353; const int inf=1e9+10; const ll INF=4e18; const ll dy[9]={1,0,-1,0,1,1,-1,-1,0}; const ll dx[9]={0,1,0,-1,1,-1,1,-1,0}; template inline bool chmax(T &a, T b) { return ((a < b) ? (a = b, true) : (false)); } template inline bool chmin(T &a, T b) { return ((a > b) ? (a = b, true) : (false)); } ll n; vl a(5e4+10); vvl g(5e4+10); ll ans=0; ll dfs(ll v,ll par=-1){ vpl pls;//{value,index} for(auto p:g[v]){ if(p==par)continue; ll f=dfs(p,v); pls.emplace_back(make_pair(f,p)); } sort(all(pls));rev(all(pls)); ll upscore=0,updeg=0; set seen; for(auto [value,idx]:pls){ if(a[idx]<=a[v])continue; if(seen.count(a[idx]))continue; if(value<0)continue; seen.insert(a[idx]); upscore+=value;updeg++; } ll downscore=0,downdeg=0; seen.clear(); for(auto [value,idx]:pls){ if(a[idx]>=a[v])continue; if(seen.count(a[idx]))continue; if(value<0)continue; seen.insert(a[idx]); downscore+=value;downdeg++; } chmax(ans,upscore+updeg*(updeg-1)); chmax(ans,downscore+downdeg*(downdeg-1)); //transition toward the parent if(v==0)return 0; if(a[par]==a[v])return -INF; upscore=0,updeg=1; seen.clear(); for(auto [value,idx]:pls){ if(a[idx]<=a[v])continue; if(a[idx]==a[par])continue; if(seen.count(a[idx]))continue; if(value<0)continue; seen.insert(a[idx]); upscore+=value;updeg++; } downscore=0,downdeg=1; seen.clear(); for(auto [value,idx]:pls){ if(a[idx]>=a[v])continue; if(a[idx]==a[par])continue; if(seen.count(a[idx]))continue; if(value<0)continue; seen.insert(a[idx]); downscore+=value;downdeg++; } if(a[par]> n; rep(i,n)cin >> a[i]; rep(i,n-1){ ll x,y;cin >> x >> y;x--;y--; g[x].emplace_back(y); g[y].emplace_back(x); } 
dfs(0,-1); cout << ans/2 << endl; }