// Greedy "peel from the leaves" solution on a tree.
//
// Input (stdin):
//   N M
//   P_0 ... P_{N-1}
//   N-1 lines "a b"   (tree edges, 1-indexed)
//   C_1 ... C_M       (special vertices, 1-indexed)
// Output (stdout): the total of P[v] over all vertices accepted by the
// greedy in greedy_collect().

#include <cstddef>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>

// Template libraries kept from the original file (all disabled):
// #include "Src/Number/IntegerDivision.hpp"
// #include "Src/Utility/BinarySearch.hpp"
// #include "Src/Sequence/CompressedSequence.hpp"
// #include "Src/Sequence/RunLengthEncoding.hpp"
// #include "Src/Algebra/Group/AdditiveGroup.hpp"
// #include "Src/DataStructure/FenwickTree/FenwickTree.hpp"
// #include "Src/DataStructure/SegmentTree/SegmentTree.hpp"
// #include "Src/DataStructure/DisjointSetUnion/DisjointSetUnion.hpp"
// using namespace zawa;
// #include "atcoder/modint"
// using mint = atcoder::modint998244353;

namespace {

using Graph = std::vector<std::vector<int>>;

// Multi-source BFS over g.
// Returns dist where dist[v] is the number of edges from v to the nearest
// vertex in `sources`, or -1 if v is unreachable from every source
// (e.g. when `sources` is empty).
std::vector<int> multi_source_bfs(const Graph& g, const std::vector<int>& sources) {
    std::vector<int> dist(g.size(), -1);
    std::vector<int> order;  // BFS queue; elements are never removed
    for (const int s : sources) {
        dist[s] = 0;
        order.push_back(s);
    }
    for (std::size_t head = 0; head < order.size(); ++head) {
        const int v = order[head];  // copy before push_back may reallocate
        for (const int x : g[v]) {
            if (dist[x] == -1) {
                dist[x] = dist[v] + 1;
                order.push_back(x);
            }
        }
    }
    return dist;
}

// Greedily collects P values, starting from the leaves (degree-1 vertices).
//
// At each step the frontier vertex with the largest (P[v], v) is popped.
// If it would be the k-th accepted vertex (0-based), it is accepted only when
// k < dist[v]. Accepting adds P[v] to the total and puts its not-yet-seen
// neighbours on the frontier. A rejected vertex is dropped for good: k never
// decreases, so it could never pass the test later, and its neighbours are not
// exposed through it. NOTE(review): presumably dist[v] is the step at which v
// becomes unavailable; confirm against the problem statement.
// Vertices with dist[v] == -1 are never accepted, since k >= 0.
long long greedy_collect(const Graph& g, const std::vector<int>& P,
                         const std::vector<int>& dist) {
    const int n = static_cast<int>(g.size());
    std::vector<char> seen(n, 0);  // already pushed onto the frontier
    std::priority_queue<std::pair<int, int>> frontier;  // (P[v], v), max first

    for (int v = 0; v < n; ++v) {
        if (g[v].size() == 1) {
            seen[v] = 1;
            frontier.emplace(P[v], v);
        }
    }

    long long total = 0;
    int taken = 0;  // number of vertices accepted so far
    while (!frontier.empty()) {
        const auto [p, v] = frontier.top();  // copy: top() is invalidated by pop()
        frontier.pop();
        if (taken >= dist[v]) {
            continue;  // rejection does not consume a step
        }
        total += p;
        ++taken;
        for (const int x : g[v]) {
            if (!seen[x]) {
                seen[x] = 1;
                frontier.emplace(P[x], x);
            }
        }
    }
    return total;
}

}  // namespace

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int N = 0;
    int M = 0;
    std::cin >> N >> M;

    // Sized from N rather than fixed global arrays, so large N cannot overflow.
    std::vector<int> P(N);
    for (int& p : P) {
        std::cin >> p;
    }

    Graph g(N);
    for (int i = 0; i + 1 < N; ++i) {
        int a = 0;
        int b = 0;
        std::cin >> a >> b;
        --a;
        --b;
        g[a].push_back(b);
        g[b].push_back(a);
    }

    std::vector<int> special(M);
    for (int& c : special) {
        std::cin >> c;
        --c;
    }

    const std::vector<int> dist = multi_source_bfs(g, special);
    std::cout << greedy_collect(g, P, dist) << '\n';
}